diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index 4be58c95..20a3e348 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -106,7 +106,7 @@ namespace LLama.Web.Common
/// <summary>
/// how split tensors should be distributed across GPUs
/// </summary>
- public float[] TensorSplits { get; set; }
+ public TensorSplitsCollection TensorSplits { get; set; } = new();
/// <summary>
/// RoPE base frequency
diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index 1ec7022f..42f4f63a 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -1,6 +1,8 @@
using System;
+using System.Buffers;
using System.Collections.Generic;
using System.Linq;
+using LLama.Native;
namespace LLama.Abstractions
{
@@ -37,7 +39,7 @@ namespace LLama.Abstractions
/// <summary>
/// how split tensors should be distributed across GPUs
/// </summary>
- float[]? TensorSplits { get; set; }
+ TensorSplitsCollection TensorSplits { get; set; }
/// <summary>
/// Load vocab only (no weights)
@@ -98,4 +100,42 @@ namespace LLama.Abstractions
}
}
}
+
+ /// <summary>
+ /// A fixed size array to set the tensor splits across multiple GPUs
+ /// </summary>
+ public sealed class TensorSplitsCollection
+ {
+ private readonly float[] _array = new float[NativeApi.llama_max_devices()];
+
+ /// <summary>
+ /// The size of this array
+ /// </summary>
+ public int Length => _array.Length;
+
+ /// <summary>
+ /// Get or set the proportion of work to do on the given device.
+ ///
+ /// "[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.
+ /// </summary>
+ /// <param name="index"></param>
+ public float this[int index]
+ {
+ get => _array[index];
+ set => _array[index] = value;
+ }
+
+ /// <summary>
+ /// Set all values to zero
+ /// </summary>
+ public void Clear()
+ {
+ Array.Clear(_array, 0, _array.Length);
+ }
+
+ internal MemoryHandle Pin()
+ {
+ return _array.AsMemory().Pin();
+ }
+ }
}
\ No newline at end of file
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index 998d4ec4..bc02de63 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -82,9 +82,10 @@ namespace LLama.Common
public bool EmbeddingMode { get; set; }
/// <summary>
- /// how split tensors should be distributed across GPUs
+ /// how split tensors should be distributed across GPUs.
/// </summary>
- public float[]? TensorSplits { get; set; }
+ /// <remarks>"[ 3, 2 ]" will assign 60% of the data to GPU 0 and 40% to GPU 1.</remarks>
+ public TensorSplitsCollection TensorSplits { get; set; }
/// <summary>
/// RoPE base frequency
diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index dc72d239..a9c2d10e 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -1,7 +1,6 @@
using System.IO;
using System;
using System.Buffers;
-using System.Diagnostics;
using LLama.Abstractions;
using LLama.Native;
@@ -22,25 +21,6 @@ namespace LLama.Extensions
///
public static MemoryHandle ToLlamaModelParams(this IModelParams @params, out LLamaModelParams result)
{
- var maxDevices = NativeApi.llama_max_devices();
- var splits = @params.TensorSplits;
- if (splits != null)
- {
- Debug.Assert(@params.TensorSplits != null);
-
- // If the splits array is too large just throw
- if (splits.Length > maxDevices)
- throw new ArgumentException($"TensorSplits size must be <= NativeApi.llama_max_devices() ({maxDevices})");
-
- // If the splits array is too small pad it up to the necessary size
- if (splits.Length < maxDevices)
- {
- splits = new float[maxDevices];
- for (var i = 0; i < @params.TensorSplits.Length; i++)
- splits[i] = @params.TensorSplits[i];
- }
- }
-
result = NativeApi.llama_model_default_params();
result.main_gpu = @params.MainGpu;
@@ -49,7 +29,7 @@ namespace LLama.Extensions
result.use_mmap = @params.UseMemorymap;
result.vocab_only = @params.VocabOnly;
- var pin = splits.AsMemory().Pin();
+ var pin = @params.TensorSplits.Pin();
unsafe
{
result.tensor_split = (float*)pin.Pointer;