diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index e37de1e9..ce712b72 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -92,10 +92,11 @@ namespace LLama
         /// </summary>
         /// <param name="params">Parameters to use to load the model</param>
         /// <param name="token">A cancellation token that can interrupt model loading</param>
+        /// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
         /// <returns></returns>
         /// <exception cref="LoadWeightsFailedException">Thrown if weights failed to load for any reason. e.g. Invalid file format or loading cancelled.</exception>
         /// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
-        public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default)
+        public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
         {
             // don't touch the @params object inside the task, it might be changed
             // externally! Save a copy of everything that we need later.
@@ -103,16 +104,25 @@ namespace LLama
             var loraBase = @params.LoraBase;
             var loraAdapters = @params.LoraAdapters.ToArray();
 
+            // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
+            // slightly smaller range to allow some space for reporting LoRA loading too.
+            var modelLoadProgressRange = 1f;
+            if (loraAdapters.Length > 0)
+                modelLoadProgressRange = 0.9f;
+
             using (@params.ToLlamaModelParams(out var lparams))
             {
 #if !NETSTANDARD2_0
-                // Overwrite the progress callback with one which polls the cancellation token
-                if (token.CanBeCanceled)
+                // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
+                if (token.CanBeCanceled || progressReporter != null)
                 {
                     var internalCallback = lparams.progress_callback;
                     lparams.progress_callback = (progress, ctx) =>
                     {
-                        // If the user set a call in the model params, first call that and see if we should cancel
+                        // Update the progress reporter (remapping the value into the smaller range).
+                        progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);
+
+                        // If the user set a callback in the model params, call that and see if we should cancel
                         if (internalCallback != null && !internalCallback(progress, ctx))
                             return false;
 
@@ -129,8 +139,11 @@ namespace LLama
             {
                 try
                 {
+                    // Load the model
                     var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);
-                    foreach (var adapter in loraAdapters)
+
+                    // Apply the LoRA adapters
+                    for (var i = 0; i < loraAdapters.Length; i++)
                     {
                         // Interrupt applying LoRAs if the token is cancelled
                         if (token.IsCancellationRequested)
@@ -140,14 +153,22 @@ namespace LLama
                         }
 
                         // Don't apply invalid adapters
+                        var adapter = loraAdapters[i];
                         if (string.IsNullOrEmpty(adapter.Path))
                             continue;
                         if (adapter.Scale <= 0)
                             continue;
 
                         weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);
+
+                        // Report progress. Model loading reported progress from 0 -> 0.9, use
+                        // the last 0.1 to represent all of the LoRA adapters being applied.
+                        progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
                     }
 
+                    // Update progress reporter to indicate completion
+                    progressReporter?.Report(1);
+
                     return new LLamaWeights(weights);
                 }
                 catch (LoadWeightsFailedException)
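
For callers of the changed method, a minimal usage sketch of the new progressReporter parameter. The model path below is a placeholder, ModelParams is LLamaSharp's stock IModelParams implementation, and Progress<float> is the standard .NET IProgress<float> wrapper; values arrive in the 0-1 range as described in the diff.

using System;
using LLama;
using LLama.Common;

// Placeholder path - substitute a real GGUF model file.
var parameters = new ModelParams("models/model.gguf");

// Progress<float> receives values between 0 and 1; when LoRA adapters are configured,
// model loading maps to 0-0.9 and adapter application fills the remaining 0.1.
var progress = new Progress<float>(p => Console.WriteLine($"Loading: {p:P0}"));

using var weights = await LLamaWeights.LoadFromFileAsync(parameters, progressReporter: progress);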