diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index e37de1e9..ce712b72 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -92,10 +92,11 @@ namespace LLama
/// </summary>
/// <param name="params">Parameters to use to load the model</param>
/// <param name="token">A cancellation token that can interrupt model loading</param>
+ /// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
/// <returns></returns>
/// <exception cref="LoadWeightsFailedException">Thrown if weights failed to load for any reason. e.g. Invalid file format or loading cancelled.</exception>
/// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
- public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default)
+ public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
{
// don't touch the @params object inside the task, it might be changed
// externally! Save a copy of everything that we need later.
@@ -103,16 +104,25 @@ namespace LLama
var loraBase = @params.LoraBase;
var loraAdapters = @params.LoraAdapters.ToArray();
+ // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
+ // slightly smaller range to allow some space for reporting LoRA loading too.
+ var modelLoadProgressRange = 1f;
+ if (loraAdapters.Length > 0)
+ modelLoadProgressRange = 0.9f;
+
using (@params.ToLlamaModelParams(out var lparams))
{
#if !NETSTANDARD2_0
- // Overwrite the progress callback with one which polls the cancellation token
- if (token.CanBeCanceled)
+ // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
+ if (token.CanBeCanceled || progressReporter != null)
{
var internalCallback = lparams.progress_callback;
lparams.progress_callback = (progress, ctx) =>
{
- // If the user set a call in the model params, first call that and see if we should cancel
+ // Update the progress reporter (remapping the value into the smaller range).
+ progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);
+
+ // If the user set a callback in the model params, call that and see if we should cancel
if (internalCallback != null && !internalCallback(progress, ctx))
return false;
@@ -129,8 +139,11 @@ namespace LLama
{
try
{
+ // Load the model
var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);
- foreach (var adapter in loraAdapters)
+
+ // Apply the LoRA adapters
+ for (var i = 0; i < loraAdapters.Length; i++)
{
// Interrupt applying LoRAs if the token is cancelled
if (token.IsCancellationRequested)
@@ -140,14 +153,22 @@ namespace LLama
}
// Don't apply invalid adapters
+ var adapter = loraAdapters[i];
if (string.IsNullOrEmpty(adapter.Path))
continue;
if (adapter.Scale <= 0)
continue;
weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);
+
+ // Report progress. Model loading reported progress from 0 -> 0.9, use
+ // the last 0.1 to represent all of the LoRA adapters being applied.
+ progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
}
+ // Update progress reporter to indicate completion
+ progressReporter?.Report(1);
+
return new LLamaWeights(weights);
}
catch (LoadWeightsFailedException)
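
Usage sketch (not part of the patch): one way a caller might consume the new progressReporter parameter, assuming a ModelParams instance named "parameters" (hypothetical variable) and the standard BCL Progress<float> wrapper for IProgress<float>.

    // Prints coarse progress to the console; reported values stay in the 0..1 range,
    // with the final 10% reserved for LoRA adapters when any are configured.
    var progress = new Progress<float>(p => Console.Write($"\rLoading model: {p:P0}"));
    using var weights = await LLamaWeights.LoadFromFileAsync(parameters, CancellationToken.None, progress);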