@@ -92,10 +92,11 @@ namespace LLama
         /// </summary>
         /// <param name="params">Parameters to use to load the model</param>
         /// <param name="token">A cancellation token that can interrupt model loading</param>
+        /// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
         /// <returns></returns>
         /// <exception cref="LoadWeightsFailedException">Thrown if weights failed to load for any reason. e.g. Invalid file format or loading cancelled.</exception>
         /// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
-        public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default)
+        public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
         {
             // don't touch the @params object inside the task, it might be changed
             // externally! Save a copy of everything that we need later.
@@ -103,16 +104,25 @@ namespace LLama
             var loraBase = @params.LoraBase;
             var loraAdapters = @params.LoraAdapters.ToArray();
 
+            // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
+            // slightly smaller range to allow some space for reporting LoRA loading too.
+            var modelLoadProgressRange = 1f;
+            if (loraAdapters.Length > 0)
+                modelLoadProgressRange = 0.9f;
+
             using (@params.ToLlamaModelParams(out var lparams))
             {
 #if !NETSTANDARD2_0
-                // Overwrite the progress callback with one which polls the cancellation token
-                if (token.CanBeCanceled)
+                // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
+                if (token.CanBeCanceled || progressReporter != null)
                 {
                     var internalCallback = lparams.progress_callback;
                     lparams.progress_callback = (progress, ctx) =>
                     {
-                        // If the user set a call in the model params, first call that and see if we should cancel
+                        // Update the progress reporter (remapping the value into the smaller range).
+                        progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);
+
+                        // If the user set a callback in the model params, call that and see if we should cancel
                         if (internalCallback != null && !internalCallback(progress, ctx))
                             return false;
 
@@ -129,8 +139,11 @@ namespace LLama
                 {
                     try
                     {
+                        // Load the model
                         var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);
-                        foreach (var adapter in loraAdapters)
+
+                        // Apply the LoRA adapters
+                        for (var i = 0; i < loraAdapters.Length; i++)
                         {
                             // Interrupt applying LoRAs if the token is cancelled
                             if (token.IsCancellationRequested)
@@ -140,14 +153,22 @@ namespace LLama
                             }
 
                             // Don't apply invalid adapters
+                            var adapter = loraAdapters[i];
                             if (string.IsNullOrEmpty(adapter.Path))
                                 continue;
                             if (adapter.Scale <= 0)
                                 continue;
 
                             weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);
+
+                            // Report progress. Model loading reported progress from 0 -> 0.9, use
+                            // the last 0.1 to represent all of the LoRA adapters being applied.
+                            progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
                         }
+
+                        // Update progress reporter to indicate completion
+                        progressReporter?.Report(1);
 
                         return new LLamaWeights(weights);
                     }
                     catch (LoadWeightsFailedException)
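
For reference, a minimal usage sketch of the new overload. The `ModelParams` type, the "model.gguf" path, the timeout, and the console output are illustrative assumptions, not part of this diff; only `LoadFromFileAsync(IModelParams, CancellationToken, IProgress<float>?)` comes from the change above.

    // Illustrative caller: load weights with a 5 minute timeout (placeholder value)
    // and report progress (0 to 1) to the console.
    var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5));
    var progress = new Progress<float>(p => Console.WriteLine($"Loading: {p:P0}"));

    var parameters = new ModelParams("model.gguf"); // assumed IModelParams implementation
    using var weights = await LLamaWeights.LoadFromFileAsync(parameters, cts.Token, progress);

With LoRA adapters configured, the model load itself reports 0 to 0.9 and each adapter advances the remaining 0.1 evenly; e.g. with 4 adapters, progress after the second adapter is 0.9 + (0.1 / 4) * 2 = 0.95.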