Interruptible Async Model Loading With Progress Monitoring
@@ -19,7 +19,7 @@ public class BatchedExecutorFork
 string modelPath = UserSettings.GetModelPath();
 var parameters = new ModelParams(modelPath);
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

@@ -19,7 +19,7 @@ public class BatchedExecutorGuidance
 string modelPath = UserSettings.GetModelPath();
 var parameters = new ModelParams(modelPath);
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
 var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();

@@ -20,7 +20,7 @@ public class BatchedExecutorRewind
 string modelPath = UserSettings.GetModelPath();
 var parameters = new ModelParams(modelPath);
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

@@ -18,7 +18,7 @@ public class BatchedExecutorSaveAndLoad
 string modelPath = UserSettings.GetModelPath();
 var parameters = new ModelParams(modelPath);
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
@@ -31,7 +31,7 @@ public class ChatChineseGB2312
 GpuLayerCount = 5,
 Encoding = Encoding.UTF8
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 using var context = model.CreateContext(parameters);
 var executor = new InteractiveExecutor(context);
@@ -15,11 +15,11 @@ public class ChatSessionStripRoleName
 Seed = 1337,
 GpuLayerCount = 5
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 using var context = model.CreateContext(parameters);
 var executor = new InteractiveExecutor(context);
-var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
+var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
 ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
 ChatSession session = new(executor, chatHistory);
@@ -13,7 +13,7 @@ public class ChatSessionWithHistory
 Seed = 1337,
 GpuLayerCount = 5
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 using var context = model.CreateContext(parameters);
 var executor = new InteractiveExecutor(context);
@@ -13,11 +13,11 @@ public class ChatSessionWithRestart
 Seed = 1337,
 GpuLayerCount = 5
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 using var context = model.CreateContext(parameters);
 var executor = new InteractiveExecutor(context);
-var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
+var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
 ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
 ChatSession prototypeSession =
     await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory);
@@ -13,11 +13,11 @@ public class ChatSessionWithRoleName
 Seed = 1337,
 GpuLayerCount = 5
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 using var context = model.CreateContext(parameters);
 var executor = new InteractiveExecutor(context);
-var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
+var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
 ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
 ChatSession session = new(executor, chatHistory);
@@ -29,7 +29,7 @@
 {
 ContextSize = 4096
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 using var context = model.CreateContext(parameters);
 var executor = new InstructExecutor(context, InstructionPrefix, InstructionSuffix, null);
@@ -9,7 +9,7 @@ namespace LLama.Examples.Examples
 {
 string modelPath = UserSettings.GetModelPath();
-var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
+var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
 var grammar = Grammar.Parse(gbnf, "root");
 var parameters = new ModelParams(modelPath)

@@ -17,7 +17,7 @@ namespace LLama.Examples.Examples
 Seed = 1337,
 GpuLayerCount = 5
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 var ex = new StatelessExecutor(model, parameters);
 Console.ForegroundColor = ConsoleColor.Yellow;
@@ -9,14 +9,14 @@ namespace LLama.Examples.Examples
 {
 string modelPath = UserSettings.GetModelPath();
-var prompt = File.ReadAllText("Assets/dan.txt").Trim();
+var prompt = (await File.ReadAllTextAsync("Assets/dan.txt")).Trim();
 var parameters = new ModelParams(modelPath)
 {
 Seed = 1337,
 GpuLayerCount = 5
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 using var context = model.CreateContext(parameters);
 var executor = new InstructExecutor(context);
@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
 Seed = 1337,
 GpuLayerCount = 5
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 using var context = model.CreateContext(parameters);
 var ex = new InteractiveExecutor(context);
@@ -20,7 +20,7 @@ namespace LLama.Examples.Examples
 var parameters = new ModelParams(modelPath);
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 using var context = model.CreateContext(parameters);
 // Llava Init
@@ -15,7 +15,7 @@ namespace LLama.Examples.Examples
 Seed = 1337,
 GpuLayerCount = 5
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 using var context = model.CreateContext(parameters);
 var ex = new InteractiveExecutor(context);
@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
 Seed = 1337,
 GpuLayerCount = 5
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 using var context = model.CreateContext(parameters);
 var ex = new InteractiveExecutor(context);
@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
 // Load weights into memory
 var parameters = new ModelParams(modelPath);
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 var ex = new StatelessExecutor(model, parameters);
 var chatGPT = new LLamaSharpChatCompletion(ex);
@@ -23,7 +23,7 @@ namespace LLama.Examples.Examples
 Embeddings = true
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 var embedding = new LLamaEmbedder(model, parameters);
 Console.WriteLine("====================================================");
@@ -19,7 +19,7 @@ namespace LLama.Examples.Examples
 // Load weights into memory
 var parameters = new ModelParams(modelPath);
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 var ex = new StatelessExecutor(model, parameters);
 var builder = Kernel.CreateBuilder();
@@ -14,7 +14,7 @@ namespace LLama.Examples.Examples
 Seed = 1337,
 GpuLayerCount = 5
 };
-using var model = LLamaWeights.LoadFromFile(parameters);
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
 var ex = new StatelessExecutor(model, parameters);
 Console.ForegroundColor = ConsoleColor.Yellow;
@@ -12,7 +12,7 @@ namespace LLama.Examples.Examples
 // Load weights into memory
 var @params = new ModelParams(modelPath);
-using var weights = LLamaWeights.LoadFromFile(@params);
+using var weights = await LLamaWeights.LoadFromFileAsync(@params);
 // Create 2 contexts sharing the same weights
 using var aliceCtx = weights.CreateContext(@params);
@@ -1,7 +1,10 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
 using LLama.Abstractions;
+using LLama.Exceptions;
 using LLama.Extensions;
 using LLama.Native;
 using Microsoft.Extensions.Logging;

@@ -84,6 +87,104 @@ namespace LLama
 return new LLamaWeights(weights);
 }
+/// <summary>
+/// Load weights into memory
+/// </summary>
+/// <param name="params">Parameters to use to load the model</param>
+/// <param name="token">A cancellation token that can interrupt model loading</param>
+/// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
+/// <returns></returns>
+/// <exception cref="LoadWeightsFailedException">Thrown if weights failed to load for any reason other than cancellation, e.g. an invalid file format.</exception>
+/// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
+public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
+{
+    // Don't touch the @params object inside the task, it might be changed
+    // externally! Save a copy of everything that we need later.
+    var modelPath = @params.ModelPath;
+    var loraBase = @params.LoraBase;
+    var loraAdapters = @params.LoraAdapters.ToArray();
+
+    // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
+    // slightly smaller range to allow some space for reporting LoRA loading too.
+    var modelLoadProgressRange = 1f;
+    if (loraAdapters.Length > 0)
+        modelLoadProgressRange = 0.9f;
+
+    using (@params.ToLlamaModelParams(out var lparams))
+    {
+#if !NETSTANDARD2_0
+        // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
+        if (token.CanBeCanceled || progressReporter != null)
+        {
+            var internalCallback = lparams.progress_callback;
+            lparams.progress_callback = (progress, ctx) =>
+            {
+                // Update the progress reporter (remapping the value into the smaller range).
+                progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);
+
+                // If the user set a callback in the model params, call that and see if we should cancel
+                if (internalCallback != null && !internalCallback(progress, ctx))
+                    return false;
+
+                // Check the cancellation token
+                if (token.IsCancellationRequested)
+                    return false;
+
+                return true;
+            };
+        }
+#endif
+
+        var model = await Task.Run(() =>
+        {
+            try
+            {
+                // Load the model
+                var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);
+
+                // Apply the LoRA adapters
+                for (var i = 0; i < loraAdapters.Length; i++)
+                {
+                    // Interrupt applying LoRAs if the token is cancelled
+                    if (token.IsCancellationRequested)
+                    {
+                        weights.Dispose();
+                        token.ThrowIfCancellationRequested();
+                    }
+
+                    // Don't apply invalid adapters
+                    var adapter = loraAdapters[i];
+                    if (string.IsNullOrEmpty(adapter.Path))
+                        continue;
+                    if (adapter.Scale <= 0)
+                        continue;
+
+                    weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);
+
+                    // Report progress. Model loading reported progress from 0 -> 0.9, use
+                    // the last 0.1 to represent all of the LoRA adapters being applied.
+                    progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
+                }
+
+                // Update progress reporter to indicate completion
+                progressReporter?.Report(1);
+
+                return new LLamaWeights(weights);
+            }
+            catch (LoadWeightsFailedException)
+            {
+                // Convert a LoadWeightsFailedException into a cancellation exception if possible.
+                token.ThrowIfCancellationRequested();
+
+                // Ok, the weights failed to load for some reason other than cancellation.
+                throw;
+            }
+        }, token);
+
+        return model;
+    }
+}
+
 /// <inheritdoc />
 public void Dispose()
 {
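The new overload is intended to be awaited from async code with an optional cancellation token and progress reporter. A minimal usage sketch follows; the model path, the console progress handler, and the 30-second timeout are illustrative assumptions, not part of this PR:

    using System;
    using System.Threading;
    using LLama;
    using LLama.Common;

    // Illustrative model path; substitute your own GGUF file.
    var parameters = new ModelParams("models/example-model.gguf");

    // Receives values in the 0..1 range as loading (and LoRA application) progresses.
    var progress = new Progress<float>(p => Console.WriteLine($"Loading: {p:P0}"));

    // Abort loading if it has not finished within 30 seconds.
    using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

    try
    {
        using var model = await LLamaWeights.LoadFromFileAsync(parameters, cts.Token, progress);
        Console.WriteLine("Model loaded.");
    }
    catch (OperationCanceledException)
    {
        Console.WriteLine("Model loading was cancelled.");
    }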
@@ -8,6 +8,8 @@ namespace LLama.Native
 /// </summary>
 /// <param name="progress"></param>
 /// <param name="ctx"></param>
+/// <returns>If the provided progress_callback returns true, model loading continues.
+/// If it returns false, model loading is immediately aborted.</returns>
 /// <remarks>llama_progress_callback</remarks>
 public delegate bool LlamaProgressCallback(float progress, IntPtr ctx);
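The documented return value is what makes loading interruptible at the native level. A small sketch of a user-supplied callback, assuming the delegate is reachable via LLama.Native; the stopRequested flag and the logging are hypothetical, not from the PR:

    using System;
    using LLama.Native;

    // Illustrative flag; real code might set this from a UI or another thread.
    var stopRequested = false;

    LlamaProgressCallback callback = (progress, ctx) =>
    {
        Console.WriteLine($"Native load progress: {progress:P0}");

        // Returning false aborts the load immediately; returning true lets it continue.
        return !stopRequested;
    };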
@@ -38,7 +38,7 @@ namespace LLama.Native
 // as NET Framework 4.8 does not play nice with the LlamaProgressCallback type
 public IntPtr progress_callback;
 #else
-public LlamaProgressCallback progress_callback;
+public LlamaProgressCallback? progress_callback;
 #endif

 /// <summary>
@@ -120,8 +120,11 @@ namespace LLama.Native
 if (!fs.CanRead)
     throw new InvalidOperationException($"Model file '{modelPath}' is not readable");

-return llama_load_model_from_file(modelPath, lparams)
-    ?? throw new LoadWeightsFailedException(modelPath);
+var handle = llama_load_model_from_file(modelPath, lparams);
+if (handle.IsInvalid)
+    throw new LoadWeightsFailedException(modelPath);
+
+return handle;
 }

 #region native API