
Added optional `IProgress` parameter to `LoadFromFileAsync`

Branch: pull/702/head
Martin Evans committed 1 year ago
Parent commit: 1ec0fee5ba
1 changed file with 26 additions and 5 deletions
LLama/LLamaWeights.cs  (+26 −5)

@@ -92,10 +92,11 @@ namespace LLama
         /// </summary>
         /// <param name="params">Parameters to use to load the model</param>
         /// <param name="token">A cancellation token that can interrupt model loading</param>
+        /// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
         /// <returns></returns>
         /// <exception cref="LoadWeightsFailedException">Thrown if weights failed to load for any reason. e.g. Invalid file format or loading cancelled.</exception>
         /// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
-        public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default)
+        public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
         {
             // don't touch the @params object inside the task, it might be changed
             // externally! Save a copy of everything that we need later.
@@ -103,16 +104,25 @@ namespace LLama
             var loraBase = @params.LoraBase;
             var loraAdapters = @params.LoraAdapters.ToArray();
 
+            // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
+            // slightly smaller range to allow some space for reporting LoRA loading too.
+            var modelLoadProgressRange = 1f;
+            if (loraAdapters.Length > 0)
+                modelLoadProgressRange = 0.9f;
+
             using (@params.ToLlamaModelParams(out var lparams))
             {
 #if !NETSTANDARD2_0
-                // Overwrite the progress callback with one which polls the cancellation token
-                if (token.CanBeCanceled)
+                // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
+                if (token.CanBeCanceled || progressReporter != null)
                 {
                     var internalCallback = lparams.progress_callback;
                     lparams.progress_callback = (progress, ctx) =>
                     {
-                        // If the user set a call in the model params, first call that and see if we should cancel
+                        // Update the progress reporter (remapping the value into the smaller range).
+                        progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);
+
+                        // If the user set a callback in the model params, call that and see if we should cancel
                         if (internalCallback != null && !internalCallback(progress, ctx))
                             return false;
 
@@ -129,8 +139,11 @@ namespace LLama
                 {
                     try
                     {
+                        // Load the model
                         var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);
-                        foreach (var adapter in loraAdapters)
+
+                        // Apply the LoRA adapters
+                        for (var i = 0; i < loraAdapters.Length; i++)
                         {
                             // Interrupt applying LoRAs if the token is cancelled
                             if (token.IsCancellationRequested)
@@ -140,14 +153,22 @@ namespace LLama
                             }
 
                             // Don't apply invalid adapters
+                            var adapter = loraAdapters[i];
                             if (string.IsNullOrEmpty(adapter.Path))
                                 continue;
                             if (adapter.Scale <= 0)
                                 continue;
 
                             weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);
+
+                            // Report progress. Model loading reported progress from 0 -> 0.9, use
+                            // the last 0.1 to represent all of the LoRA adapters being applied.
+                            progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
                         }
 
+                        // Update progress reporter to indicate completion
+                        progressReporter?.Report(1);
+
                         return new LLamaWeights(weights);
                     }
                     catch (LoadWeightsFailedException)
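
For callers, the new parameter composes with the existing cancellation support. Below is a minimal usage sketch, assuming the `ModelParams` class from `LLama.Common` provides the `IModelParams` implementation; the model path is illustrative:

```csharp
using System;
using System.Threading;
using LLama;
using LLama.Common;

// Illustrative model path; substitute a real GGUF file.
var parameters = new ModelParams("models/llama-7b.gguf");

using var cts = new CancellationTokenSource();

// Progress<float> receives values from 0 to 1. When LoRA adapters are
// configured, 0 to 0.9 covers the base model load and the final 0.1 is
// split evenly across the adapters as each one is applied.
var progress = new Progress<float>(p => Console.WriteLine($"Loading: {p:P0}"));

using var weights = await LLamaWeights.LoadFromFileAsync(parameters, cts.Token, progress);
```

Note that `Progress<T>` posts its callbacks to the `SynchronizationContext` captured at construction, so in a UI application the report handler runs on the UI thread rather than on the loader's worker thread.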

