diff --git a/LLama.Examples/Examples/BatchedExecutorFork.cs b/LLama.Examples/Examples/BatchedExecutorFork.cs
index febba5c3..2c401822 100644
--- a/LLama.Examples/Examples/BatchedExecutorFork.cs
+++ b/LLama.Examples/Examples/BatchedExecutorFork.cs
@@ -19,7 +19,7 @@ public class BatchedExecutorFork
string modelPath = UserSettings.GetModelPath();
var parameters = new ModelParams(modelPath);
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
index 6f3eceab..b006c88b 100644
--- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs
+++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -19,7 +19,7 @@ public class BatchedExecutorGuidance
string modelPath = UserSettings.GetModelPath();
var parameters = new ModelParams(modelPath);
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
diff --git a/LLama.Examples/Examples/BatchedExecutorRewind.cs b/LLama.Examples/Examples/BatchedExecutorRewind.cs
index 8aae92af..938b3106 100644
--- a/LLama.Examples/Examples/BatchedExecutorRewind.cs
+++ b/LLama.Examples/Examples/BatchedExecutorRewind.cs
@@ -20,7 +20,7 @@ public class BatchedExecutorRewind
string modelPath = UserSettings.GetModelPath();
var parameters = new ModelParams(modelPath);
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
diff --git a/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs b/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs
index af0dea52..0ec903eb 100644
--- a/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs
+++ b/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs
@@ -18,7 +18,7 @@ public class BatchedExecutorSaveAndLoad
string modelPath = UserSettings.GetModelPath();
var parameters = new ModelParams(modelPath);
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs
index a5db02cd..c59a522f 100644
--- a/LLama.Examples/Examples/ChatChineseGB2312.cs
+++ b/LLama.Examples/Examples/ChatChineseGB2312.cs
@@ -31,7 +31,7 @@ public class ChatChineseGB2312
GpuLayerCount = 5,
Encoding = Encoding.UTF8
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);
diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs
index ff0b369d..5469aa8f 100644
--- a/LLama.Examples/Examples/ChatSessionStripRoleName.cs
+++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs
@@ -15,11 +15,11 @@ public class ChatSessionStripRoleName
Seed = 1337,
GpuLayerCount = 5
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);
- var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
+ var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
ChatSession session = new(executor, chatHistory);
diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs
index da5e3ad0..af7d7eac 100644
--- a/LLama.Examples/Examples/ChatSessionWithHistory.cs
+++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs
@@ -13,7 +13,7 @@ public class ChatSessionWithHistory
Seed = 1337,
GpuLayerCount = 5
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);
diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs
index 48754a81..c2bfb895 100644
--- a/LLama.Examples/Examples/ChatSessionWithRestart.cs
+++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs
@@ -13,11 +13,11 @@ public class ChatSessionWithRestart
Seed = 1337,
GpuLayerCount = 5
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);
- var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
+ var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
ChatSession prototypeSession =
await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory);
diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs
index 08f7666b..4e2befd9 100644
--- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs
+++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs
@@ -13,11 +13,11 @@ public class ChatSessionWithRoleName
Seed = 1337,
GpuLayerCount = 5
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);
- var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
+ var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
ChatSession session = new(executor, chatHistory);
diff --git a/LLama.Examples/Examples/CodingAssistant.cs b/LLama.Examples/Examples/CodingAssistant.cs
index 808c3904..a2edf8be 100644
--- a/LLama.Examples/Examples/CodingAssistant.cs
+++ b/LLama.Examples/Examples/CodingAssistant.cs
@@ -29,7 +29,7 @@
{
ContextSize = 4096
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InstructExecutor(context, InstructionPrefix, InstructionSuffix, null);
diff --git a/LLama.Examples/Examples/GrammarJsonResponse.cs b/LLama.Examples/Examples/GrammarJsonResponse.cs
index 139bd4ac..a5bb5486 100644
--- a/LLama.Examples/Examples/GrammarJsonResponse.cs
+++ b/LLama.Examples/Examples/GrammarJsonResponse.cs
@@ -9,7 +9,7 @@ namespace LLama.Examples.Examples
{
string modelPath = UserSettings.GetModelPath();
- var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
+ var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
var grammar = Grammar.Parse(gbnf, "root");
var parameters = new ModelParams(modelPath)
@@ -17,7 +17,7 @@ namespace LLama.Examples.Examples
Seed = 1337,
GpuLayerCount = 5
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var ex = new StatelessExecutor(model, parameters);
Console.ForegroundColor = ConsoleColor.Yellow;
diff --git a/LLama.Examples/Examples/InstructModeExecute.cs b/LLama.Examples/Examples/InstructModeExecute.cs
index 73b5da79..4f65dd23 100644
--- a/LLama.Examples/Examples/InstructModeExecute.cs
+++ b/LLama.Examples/Examples/InstructModeExecute.cs
@@ -9,14 +9,14 @@ namespace LLama.Examples.Examples
{
string modelPath = UserSettings.GetModelPath();
- var prompt = File.ReadAllText("Assets/dan.txt").Trim();
+ var prompt = (await File.ReadAllTextAsync("Assets/dan.txt")).Trim();
var parameters = new ModelParams(modelPath)
{
Seed = 1337,
GpuLayerCount = 5
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InstructExecutor(context);
diff --git a/LLama.Examples/Examples/InteractiveModeExecute.cs b/LLama.Examples/Examples/InteractiveModeExecute.cs
index d7d364fb..15a9c94c 100644
--- a/LLama.Examples/Examples/InteractiveModeExecute.cs
+++ b/LLama.Examples/Examples/InteractiveModeExecute.cs
@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
Seed = 1337,
GpuLayerCount = 5
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);
diff --git a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
index 112fe23f..170bab0c 100644
--- a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
+++ b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
@@ -20,7 +20,7 @@ namespace LLama.Examples.Examples
var parameters = new ModelParams(modelPath);
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
// Llava Init
diff --git a/LLama.Examples/Examples/LoadAndSaveSession.cs b/LLama.Examples/Examples/LoadAndSaveSession.cs
index d8a31017..68ed8aa3 100644
--- a/LLama.Examples/Examples/LoadAndSaveSession.cs
+++ b/LLama.Examples/Examples/LoadAndSaveSession.cs
@@ -15,7 +15,7 @@ namespace LLama.Examples.Examples
Seed = 1337,
GpuLayerCount = 5
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);
diff --git a/LLama.Examples/Examples/LoadAndSaveState.cs b/LLama.Examples/Examples/LoadAndSaveState.cs
index 9cf93e7f..0fef49f1 100644
--- a/LLama.Examples/Examples/LoadAndSaveState.cs
+++ b/LLama.Examples/Examples/LoadAndSaveState.cs
@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
Seed = 1337,
GpuLayerCount = 5
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);
diff --git a/LLama.Examples/Examples/SemanticKernelChat.cs b/LLama.Examples/Examples/SemanticKernelChat.cs
index 258ca86b..2631cc9b 100644
--- a/LLama.Examples/Examples/SemanticKernelChat.cs
+++ b/LLama.Examples/Examples/SemanticKernelChat.cs
@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
// Load weights into memory
var parameters = new ModelParams(modelPath);
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var ex = new StatelessExecutor(model, parameters);
var chatGPT = new LLamaSharpChatCompletion(ex);
diff --git a/LLama.Examples/Examples/SemanticKernelMemory.cs b/LLama.Examples/Examples/SemanticKernelMemory.cs
index 46c9a17d..3fad5ae0 100644
--- a/LLama.Examples/Examples/SemanticKernelMemory.cs
+++ b/LLama.Examples/Examples/SemanticKernelMemory.cs
@@ -23,7 +23,7 @@ namespace LLama.Examples.Examples
Embeddings = true
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var embedding = new LLamaEmbedder(model, parameters);
Console.WriteLine("====================================================");
diff --git a/LLama.Examples/Examples/SemanticKernelPrompt.cs b/LLama.Examples/Examples/SemanticKernelPrompt.cs
index fdf58b3a..63e848cb 100644
--- a/LLama.Examples/Examples/SemanticKernelPrompt.cs
+++ b/LLama.Examples/Examples/SemanticKernelPrompt.cs
@@ -19,7 +19,7 @@ namespace LLama.Examples.Examples
// Load weights into memory
var parameters = new ModelParams(modelPath);
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var ex = new StatelessExecutor(model, parameters);
var builder = Kernel.CreateBuilder();
diff --git a/LLama.Examples/Examples/StatelessModeExecute.cs b/LLama.Examples/Examples/StatelessModeExecute.cs
index 4d2edd19..806616e7 100644
--- a/LLama.Examples/Examples/StatelessModeExecute.cs
+++ b/LLama.Examples/Examples/StatelessModeExecute.cs
@@ -14,7 +14,7 @@ namespace LLama.Examples.Examples
Seed = 1337,
GpuLayerCount = 5
};
- using var model = LLamaWeights.LoadFromFile(parameters);
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var ex = new StatelessExecutor(model, parameters);
Console.ForegroundColor = ConsoleColor.Yellow;
diff --git a/LLama.Examples/Examples/TalkToYourself.cs b/LLama.Examples/Examples/TalkToYourself.cs
index bf72423f..f888209a 100644
--- a/LLama.Examples/Examples/TalkToYourself.cs
+++ b/LLama.Examples/Examples/TalkToYourself.cs
@@ -12,7 +12,7 @@ namespace LLama.Examples.Examples
// Load weights into memory
var @params = new ModelParams(modelPath);
- using var weights = LLamaWeights.LoadFromFile(@params);
+ using var weights = await LLamaWeights.LoadFromFileAsync(@params);
// Create 2 contexts sharing the same weights
using var aliceCtx = weights.CreateContext(@params);
diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index 2d8ea4d9..ce712b72 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -1,7 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
using LLama.Abstractions;
+using LLama.Exceptions;
using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.Logging;
@@ -84,6 +87,104 @@ namespace LLama
return new LLamaWeights(weights);
}
+ /// <summary>
+ /// Load weights into memory
+ /// </summary>
+ /// <param name="params">Parameters to use to load the model</param>
+ /// <param name="token">A cancellation token that can interrupt model loading</param>
+ /// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
+ /// <returns>The loaded model weights</returns>
+ /// <exception cref="LoadWeightsFailedException">Thrown if the weights failed to load for any reason, e.g. invalid file format or loading cancelled.</exception>
+ /// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
+ public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
+ {
+ // don't touch the @params object inside the task, it might be changed
+ // externally! Save a copy of everything that we need later.
+ var modelPath = @params.ModelPath;
+ var loraBase = @params.LoraBase;
+ var loraAdapters = @params.LoraAdapters.ToArray();
+
+ // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
+ // slightly smaller range to allow some space for reporting LoRA loading too.
+ var modelLoadProgressRange = 1f;
+ if (loraAdapters.Length > 0)
+ modelLoadProgressRange = 0.9f;
+
+ using (@params.ToLlamaModelParams(out var lparams))
+ {
+#if !NETSTANDARD2_0
+ // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
+ if (token.CanBeCanceled || progressReporter != null)
+ {
+ var internalCallback = lparams.progress_callback;
+ lparams.progress_callback = (progress, ctx) =>
+ {
+ // Update the progress reporter (remapping the value into the smaller range).
+ progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);
+
+ // If the user set a callback in the model params, call that and see if we should cancel
+ if (internalCallback != null && !internalCallback(progress, ctx))
+ return false;
+
+ // Check the cancellation token
+ if (token.IsCancellationRequested)
+ return false;
+
+ return true;
+ };
+ }
+#endif
+
+ var model = await Task.Run(() =>
+ {
+ try
+ {
+ // Load the model
+ var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);
+
+ // Apply the LoRA adapters
+ for (var i = 0; i < loraAdapters.Length; i++)
+ {
+ // Interrupt applying LoRAs if the token is cancelled
+ if (token.IsCancellationRequested)
+ {
+ weights.Dispose();
+ token.ThrowIfCancellationRequested();
+ }
+
+ // Don't apply invalid adapters
+ var adapter = loraAdapters[i];
+ if (string.IsNullOrEmpty(adapter.Path))
+ continue;
+ if (adapter.Scale <= 0)
+ continue;
+
+ weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);
+
+ // Report progress. Model loading reported progress from 0 -> 0.9, use
+ // the last 0.1 to represent all of the LoRA adapters being applied.
+ progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
+ }
+
+ // Update progress reporter to indicate completion
+ progressReporter?.Report(1);
+
+ return new LLamaWeights(weights);
+ }
+ catch (LoadWeightsFailedException)
+ {
+ // Convert a LoadWeightsFailedException into a cancellation exception if possible.
+ token.ThrowIfCancellationRequested();
+
+ // Ok, the weights failed to load for some reason other than cancellation.
+ throw;
+ }
+ }, token);
+
+ return model;
+ }
+ }
+
/// <inheritdoc />
public void Dispose()
{
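
For reference, a minimal caller-side sketch of the new API (illustrative only, not part of this patch; the model path, timeout and console output are placeholders):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

public static class AsyncLoadExample
{
    public static async Task Main()
    {
        var parameters = new ModelParams("path/to/model.gguf");

        // Receive progress updates (0 to 1) while the weights load
        var progress = new Progress<float>(p => Console.WriteLine($"Loading: {p:P0}"));

        // Allow the load to be abandoned, e.g. after a timeout
        using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5));

        using var model = await LLamaWeights.LoadFromFileAsync(parameters, cts.Token, progress);
        Console.WriteLine("Model loaded.");
    }
}
```
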
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 8e3d7f74..1ea52e6b 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -8,6 +8,8 @@ namespace LLama.Native
/// </summary>
/// <param name="progress"></param>
/// <param name="ctx"></param>
+ /// <returns>If the provided progress_callback returns true, model loading continues.
+ /// If it returns false, model loading is immediately aborted.</returns>
/// <remarks>llama_progress_callback</remarks>
public delegate bool LlamaProgressCallback(float progress, IntPtr ctx);
diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs
index 923b042c..6fca41fc 100644
--- a/LLama/Native/LLamaModelParams.cs
+++ b/LLama/Native/LLamaModelParams.cs
@@ -38,7 +38,7 @@ namespace LLama.Native
// as NET Framework 4.8 does not play nice with the LlamaProgressCallback type
public IntPtr progress_callback;
#else
- public LlamaProgressCallback progress_callback;
+ public LlamaProgressCallback? progress_callback;
#endif
/// <summary>
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 23c1f767..2758c050 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -120,8 +120,11 @@ namespace LLama.Native
if (!fs.CanRead)
throw new InvalidOperationException($"Model file '{modelPath}' is not readable");
- return llama_load_model_from_file(modelPath, lparams)
- ?? throw new LoadWeightsFailedException(modelPath);
+ var handle = llama_load_model_from_file(modelPath, lparams);
+ if (handle.IsInvalid)
+ throw new LoadWeightsFailedException(modelPath);
+
+ return handle;
}
#region native API
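
And a sketch of the cancellation behaviour described above: a load interrupted via the token surfaces as an OperationCanceledException rather than a LoadWeightsFailedException (again illustrative; the model path and one-second deadline are arbitrary placeholders):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

public static class CancelledLoadExample
{
    public static async Task Main()
    {
        var parameters = new ModelParams("path/to/model.gguf");

        using var cts = new CancellationTokenSource();
        cts.CancelAfter(TimeSpan.FromSeconds(1)); // give up after one second

        try
        {
            using var model = await LLamaWeights.LoadFromFileAsync(parameters, cts.Token);
            Console.WriteLine("Loaded before the deadline.");
        }
        catch (OperationCanceledException)
        {
            Console.WriteLine("Load was cancelled.");
        }
    }
}
```
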