| @@ -7,6 +7,7 @@ Console.WriteLine(" __ __ ____ _ | |||
| Console.WriteLine("======================================================================================================"); | |||
| NativeLibraryConfig.Default.WithCuda().WithLogs(); | |||
| NativeApi.llama_empty_call(); | |||
| Console.WriteLine(); | |||
| @@ -1,66 +1,67 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| namespace LLama; | |||
| internal sealed class AntipromptProcessor | |||
| namespace LLama | |||
| { | |||
| private int _longestAntiprompt; | |||
| private readonly List<string> _antiprompts = new(); | |||
| private string? _string; | |||
| public AntipromptProcessor(IEnumerable<string>? antiprompts = null) | |||
| internal sealed class AntipromptProcessor | |||
| { | |||
| if (antiprompts != null) | |||
| SetAntiprompts(antiprompts); | |||
| } | |||
| private int _longestAntiprompt; | |||
| private readonly List<string> _antiprompts = new(); | |||
| /// <summary> | |||
| /// Add an antiprompt to the collection | |||
| /// </summary> | |||
| /// <param name="antiprompt"></param> | |||
| public void AddAntiprompt(string antiprompt) | |||
| { | |||
| _antiprompts.Add(antiprompt); | |||
| _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length); | |||
| } | |||
| private string? _string; | |||
| /// <summary> | |||
| /// Overwrite all current antiprompts with a new set | |||
| /// </summary> | |||
| /// <param name="antiprompts"></param> | |||
| public void SetAntiprompts(IEnumerable<string> antiprompts) | |||
| { | |||
| _antiprompts.Clear(); | |||
| _antiprompts.AddRange(antiprompts); | |||
| public AntipromptProcessor(IEnumerable<string>? antiprompts = null) | |||
| { | |||
| if (antiprompts != null) | |||
| SetAntiprompts(antiprompts); | |||
| } | |||
| _longestAntiprompt = 0; | |||
| foreach (var antiprompt in _antiprompts) | |||
| /// <summary> | |||
| /// Add an antiprompt to the collection | |||
| /// </summary> | |||
| /// <param name="antiprompt"></param> | |||
| public void AddAntiprompt(string antiprompt) | |||
| { | |||
| _antiprompts.Add(antiprompt); | |||
| _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Add some text and check if the buffer now ends with any antiprompt | |||
| /// </summary> | |||
| /// <param name="text"></param> | |||
| /// <returns>true if the text buffer ends with any antiprompt</returns> | |||
| public bool Add(string text) | |||
| { | |||
| _string += text; | |||
| /// <summary> | |||
| /// Overwrite all current antiprompts with a new set | |||
| /// </summary> | |||
| /// <param name="antiprompts"></param> | |||
| public void SetAntiprompts(IEnumerable<string> antiprompts) | |||
| { | |||
| _antiprompts.Clear(); | |||
| _antiprompts.AddRange(antiprompts); | |||
| _longestAntiprompt = 0; | |||
| foreach (var antiprompt in _antiprompts) | |||
| _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length); | |||
| } | |||
| /// <summary> | |||
| /// Add some text and check if the buffer now ends with any antiprompt | |||
| /// </summary> | |||
| /// <param name="text"></param> | |||
| /// <returns>true if the text buffer ends with any antiprompt</returns> | |||
| public bool Add(string text) | |||
| { | |||
| _string += text; | |||
| // When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length). | |||
| // This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode | |||
| // even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances! | |||
| var maxLength = Math.Max(32, _longestAntiprompt * 4); | |||
| var trimLength = Math.Max(16, _longestAntiprompt * 2); | |||
| if (_string.Length > maxLength) | |||
| _string = _string.Substring(_string.Length - trimLength); | |||
| // When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length). | |||
| // This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode | |||
| // even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances! | |||
| var maxLength = Math.Max(32, _longestAntiprompt * 4); | |||
| var trimLength = Math.Max(16, _longestAntiprompt * 2); | |||
| if (_string.Length > maxLength) | |||
| _string = _string.Substring(_string.Length - trimLength); | |||
| foreach (var antiprompt in _antiprompts) | |||
| if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture)) | |||
| return true; | |||
| foreach (var antiprompt in _antiprompts) | |||
| if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture)) | |||
| return true; | |||
| return false; | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,356 @@ | |||
| using LLama.Exceptions; | |||
| using Microsoft.Extensions.Logging; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using System.Text.Json; | |||
| namespace LLama.Native | |||
| { | |||
| public partial class NativeApi | |||
| { | |||
| static NativeApi() | |||
| { | |||
| // Try to load a preferred library, based on CPU feature detection | |||
| TryLoadLibrary(); | |||
| try | |||
| { | |||
| llama_empty_call(); | |||
| } | |||
| catch (DllNotFoundException) | |||
| { | |||
| throw new RuntimeError("The native library cannot be found. It could be one of the following reasons: \n" + | |||
| "1. No LLamaSharp backend was installed. Please search LLamaSharp.Backend and install one of them. \n" + | |||
| "2. You are using a device with only CPU but installed cuda backend. Please install cpu backend instead. \n" + | |||
| "3. The backend is not compatible with your system cuda environment. Please check and fix it. If the environment is " + | |||
| "expected not to be changed, then consider build llama.cpp from source or submit an issue to LLamaSharp.\n" + | |||
| "4. One of the dependency of the native library is missed.\n"); | |||
| } | |||
| llama_backend_init(false); | |||
| } | |||
| private static void Log(string message, LogLevel level) | |||
| { | |||
| if (!enableLogging) return; | |||
| Debug.Assert(level is LogLevel.Information or LogLevel.Error or LogLevel.Warning); | |||
| ConsoleColor color; | |||
| string levelPrefix; | |||
| if (level == LogLevel.Information) | |||
| { | |||
| color = ConsoleColor.Green; | |||
| levelPrefix = "[Info]"; | |||
| } | |||
| else if (level == LogLevel.Error) | |||
| { | |||
| color = ConsoleColor.Red; | |||
| levelPrefix = "[Error]"; | |||
| } | |||
| else | |||
| { | |||
| color = ConsoleColor.Yellow; | |||
| levelPrefix = "[Error]"; | |||
| } | |||
| Console.ForegroundColor = color; | |||
| Console.WriteLine($"{loggingPrefix} {levelPrefix} {message}"); | |||
| Console.ResetColor(); | |||
| } | |||
| private static int GetCudaMajorVersion() | |||
| { | |||
| string? cudaPath; | |||
| string version = ""; | |||
| if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) | |||
| { | |||
| cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH"); | |||
| if (cudaPath is null) | |||
| { | |||
| return -1; | |||
| } | |||
| version = GetCudaVersionFromPath(cudaPath); | |||
| } | |||
| else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) | |||
| { | |||
| // Try the default first | |||
| cudaPath = "/usr/local/bin/cuda"; | |||
| version = GetCudaVersionFromPath(cudaPath); | |||
| if (string.IsNullOrEmpty(version)) | |||
| { | |||
| cudaPath = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH"); | |||
| if (cudaPath is null) | |||
| { | |||
| return -1; | |||
| } | |||
| foreach (var path in cudaPath.Split(':')) | |||
| { | |||
| version = GetCudaVersionFromPath(Path.Combine(path, "..")); | |||
| if (string.IsNullOrEmpty(version)) | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (string.IsNullOrEmpty(version)) | |||
| { | |||
| return -1; | |||
| } | |||
| else | |||
| { | |||
| version = version.Split('.')[0]; | |||
| bool success = int.TryParse(version, out var majorVersion); | |||
| if (success) | |||
| { | |||
| return majorVersion; | |||
| } | |||
| else | |||
| { | |||
| return -1; | |||
| } | |||
| } | |||
| } | |||
| private static string GetCudaVersionFromPath(string cudaPath) | |||
| { | |||
| try | |||
| { | |||
| string json = File.ReadAllText(Path.Combine(cudaPath, cudaVersionFile)); | |||
| using (JsonDocument document = JsonDocument.Parse(json)) | |||
| { | |||
| JsonElement root = document.RootElement; | |||
| JsonElement cublasNode = root.GetProperty("libcublas"); | |||
| JsonElement versionNode = cublasNode.GetProperty("version"); | |||
| if (versionNode.ValueKind == JsonValueKind.Undefined) | |||
| { | |||
| return string.Empty; | |||
| } | |||
| return versionNode.GetString(); | |||
| } | |||
| } | |||
| catch (Exception) | |||
| { | |||
| return string.Empty; | |||
| } | |||
| } | |||
| #if NET6_0_OR_GREATER | |||
| private static string GetAvxLibraryPath(NativeLibraryConfig.AvxLevel avxLevel, string prefix, string suffix) | |||
| { | |||
| var avxStr = NativeLibraryConfig.AvxLevelToString(avxLevel); | |||
| if (!string.IsNullOrEmpty(avxStr)) | |||
| { | |||
| avxStr += "/"; | |||
| } | |||
| return $"{prefix}{avxStr}{libraryName}{suffix}"; | |||
| } | |||
| private static List<string> GetLibraryTryOrder(NativeLibraryConfig.Description configuration) | |||
| { | |||
| OSPlatform platform; | |||
| string prefix, suffix; | |||
| if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) | |||
| { | |||
| platform = OSPlatform.Windows; | |||
| prefix = "runtimes/win-x64/native/"; | |||
| suffix = ".dll"; | |||
| } | |||
| else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) | |||
| { | |||
| platform = OSPlatform.Linux; | |||
| prefix = "runtimes/linux-x64/native/"; | |||
| suffix = ".so"; | |||
| } | |||
| else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) | |||
| { | |||
| platform = OSPlatform.OSX; | |||
| suffix = ".dylib"; | |||
| if (System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported) | |||
| { | |||
| prefix = "runtimes/osx-arm64/native/"; | |||
| } | |||
| else | |||
| { | |||
| prefix = "runtimes/osx-x64/native/"; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| throw new RuntimeError($"Your system plarform is not supported, please open an issue in LLamaSharp."); | |||
| } | |||
| Log($"Detected OS Platform: {platform}", LogLevel.Information); | |||
| List<string> result = new(); | |||
| if (configuration.UseCuda && (platform == OSPlatform.Windows || platform == OSPlatform.Linux)) // no cuda on macos | |||
| { | |||
| int cudaVersion = GetCudaMajorVersion(); | |||
| // TODO: load cuda library with avx | |||
| if (cudaVersion == -1 && !configuration.AllowFallback) | |||
| { | |||
| // if check skipped, we just try to load cuda libraries one by one. | |||
| if (configuration.SkipCheck) | |||
| { | |||
| result.Add($"{prefix}cuda12/{libraryName}{suffix}"); | |||
| result.Add($"{prefix}cuda11/{libraryName}{suffix}"); | |||
| } | |||
| else | |||
| { | |||
| throw new RuntimeError("Configured to load a cuda library but no cuda detected on your device."); | |||
| } | |||
| } | |||
| else if (cudaVersion == 11) | |||
| { | |||
| Log($"Detected cuda major version {cudaVersion}.", LogLevel.Information); | |||
| result.Add($"{prefix}cuda11/{libraryName}{suffix}"); | |||
| } | |||
| else if (cudaVersion == 12) | |||
| { | |||
| Log($"Detected cuda major version {cudaVersion}.", LogLevel.Information); | |||
| result.Add($"{prefix}cuda12/{libraryName}{suffix}"); | |||
| } | |||
| else if (cudaVersion > 0) | |||
| { | |||
| throw new RuntimeError($"Cuda version {cudaVersion} hasn't been supported by LLamaSharp, please open an issue for it."); | |||
| } | |||
| // otherwise no cuda detected but allow fallback | |||
| } | |||
| // use cpu (or mac possibly with metal) | |||
| if (!configuration.AllowFallback && platform != OSPlatform.OSX) | |||
| { | |||
| result.Add(GetAvxLibraryPath(configuration.AvxLevel, prefix, suffix)); | |||
| } | |||
| else if (platform != OSPlatform.OSX) // in macos there's absolutely no avx | |||
| { | |||
| #if NET8_0_OR_GREATER | |||
| if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx512) | |||
| { | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx512, prefix, suffix))); | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix))); | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix))); | |||
| } | |||
| else | |||
| #endif | |||
| if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx2) | |||
| { | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix)); | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix)); | |||
| } | |||
| else if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx) | |||
| { | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix)); | |||
| } | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.None, prefix, suffix)); | |||
| } | |||
| if (platform == OSPlatform.OSX) | |||
| { | |||
| result.Add($"{prefix}{libraryName}{suffix}"); | |||
| } | |||
| return result; | |||
| } | |||
| #endif | |||
| /// <summary> | |||
| /// Try to load libllama, using CPU feature detection to try and load a more specialised DLL if possible | |||
| /// </summary> | |||
| /// <returns>The library handle to unload later, or IntPtr.Zero if no library was loaded</returns> | |||
| private static IntPtr TryLoadLibrary() | |||
| { | |||
| #if NET6_0_OR_GREATER | |||
| var configuration = NativeLibraryConfig.CheckAndGatherDescription(); | |||
| enableLogging = configuration.Logging; | |||
| // We move the flag to avoid loading library when the variable is called else where. | |||
| NativeLibraryConfig.LibraryHasLoaded = true; | |||
| if (!string.IsNullOrEmpty(configuration.Path)) | |||
| { | |||
| // When loading the user specified library, there's no fallback. | |||
| var success = NativeLibrary.TryLoad(configuration.Path, out var result); | |||
| if (!success) | |||
| { | |||
| throw new RuntimeError($"Failed to load the native library [{configuration.Path}] you specified."); | |||
| } | |||
| Log($"Successfully loaded the library [{configuration.Path}] specified by user", LogLevel.Information); | |||
| return result; | |||
| } | |||
| var libraryTryLoadOrder = GetLibraryTryOrder(configuration); | |||
| string[] possiblePathPrefix = new string[] { | |||
| System.AppDomain.CurrentDomain.BaseDirectory, | |||
| Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? "" | |||
| }; | |||
| var tryFindPath = (string filename) => | |||
| { | |||
| int i = 0; | |||
| while (!File.Exists(filename)) | |||
| { | |||
| if (i < possiblePathPrefix.Length) | |||
| { | |||
| filename = Path.Combine(possiblePathPrefix[i], filename); | |||
| i++; | |||
| } | |||
| else | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| return filename; | |||
| }; | |||
| foreach (var libraryPath in libraryTryLoadOrder) | |||
| { | |||
| var fullPath = tryFindPath(libraryPath); | |||
| var result = TryLoad(fullPath, true); | |||
| if (result is not null && result != IntPtr.Zero) | |||
| { | |||
| Log($"{fullPath} is selected and loaded successfully.", LogLevel.Information); | |||
| return result ?? IntPtr.Zero; | |||
| } | |||
| else | |||
| { | |||
| Log($"Tried to load {fullPath} but failed.", LogLevel.Information); | |||
| } | |||
| } | |||
| if (!configuration.AllowFallback) | |||
| { | |||
| throw new RuntimeError("Failed to load the library that match your rule, please" + | |||
| " 1) check your rule." + | |||
| " 2) try to allow fallback." + | |||
| " 3) or open an issue if it's expected to be successful."); | |||
| } | |||
| #endif | |||
| Log($"No library was loaded before calling native apis. " + | |||
| $"This is not an error under netstandard2.0 but needs attention with net6 or higher.", LogLevel.Warning); | |||
| return IntPtr.Zero; | |||
| #if NET6_0_OR_GREATER | |||
| // Try to load a DLL from the path if supported. Returns null if nothing is loaded. | |||
| static IntPtr? TryLoad(string path, bool supported = true) | |||
| { | |||
| if (!supported) | |||
| return null; | |||
| if (NativeLibrary.TryLoad(path, out var handle)) | |||
| return handle; | |||
| return null; | |||
| } | |||
| #endif | |||
| } | |||
| private const string libraryName = "libllama"; | |||
| private const string cudaVersionFile = "version.json"; | |||
| private const string loggingPrefix = "[LLamaSharp Native]"; | |||
| private static bool enableLogging = false; | |||
| } | |||
| } | |||
| @@ -1,11 +1,7 @@ | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using System.Text.Json; | |||
| using LLama.Exceptions; | |||
| #pragma warning disable IDE1006 // Naming Styles | |||
| @@ -25,312 +21,6 @@ namespace LLama.Native | |||
| /// </summary> | |||
| public unsafe partial class NativeApi | |||
| { | |||
| static NativeApi() | |||
| { | |||
| // Try to load a preferred library, based on CPU feature detection | |||
| TryLoadLibrary(); | |||
| try | |||
| { | |||
| llama_empty_call(); | |||
| } | |||
| catch (DllNotFoundException) | |||
| { | |||
| throw new RuntimeError("The native library cannot be found. It could be one of the following reasons: \n" + | |||
| "1. No LLamaSharp backend was installed. Please search LLamaSharp.Backend and install one of them. \n" + | |||
| "2. You are using a device with only CPU but installed cuda backend. Please install cpu backend instead. \n" + | |||
| "3. The backend is not compatible with your system cuda environment. Please check and fix it. If the environment is " + | |||
| "expected not to be changed, then consider build llama.cpp from source or submit an issue to LLamaSharp.\n" + | |||
| "4. One of the dependency of the native library is missed.\n"); | |||
| } | |||
| llama_backend_init(false); | |||
| } | |||
| private static int GetCudaMajorVersion() | |||
| { | |||
| string? cudaPath; | |||
| string version = ""; | |||
| if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) | |||
| { | |||
| cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH"); | |||
| if(cudaPath is null) | |||
| { | |||
| return -1; | |||
| } | |||
| version = GetCudaVersionFromPath(cudaPath); | |||
| } | |||
| else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) | |||
| { | |||
| // Try the default first | |||
| cudaPath = "/usr/local/bin/cuda"; | |||
| version = GetCudaVersionFromPath(cudaPath); | |||
| if (string.IsNullOrEmpty(version)) | |||
| { | |||
| cudaPath = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH"); | |||
| if(cudaPath is null) | |||
| { | |||
| return -1; | |||
| } | |||
| foreach(var path in cudaPath.Split(':')) | |||
| { | |||
| version = GetCudaVersionFromPath(Path.Combine(path, "..")); | |||
| if (string.IsNullOrEmpty(version)) | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (string.IsNullOrEmpty(version)) | |||
| { | |||
| return -1; | |||
| } | |||
| else | |||
| { | |||
| version = version.Split('.')[0]; | |||
| bool success = int.TryParse(version, out var majorVersion); | |||
| if (success) | |||
| { | |||
| return majorVersion; | |||
| } | |||
| else | |||
| { | |||
| return -1; | |||
| } | |||
| } | |||
| } | |||
| private static string GetCudaVersionFromPath(string cudaPath) | |||
| { | |||
| try | |||
| { | |||
| string json = File.ReadAllText(Path.Combine(cudaPath, cudaVersionFile)); | |||
| using (JsonDocument document = JsonDocument.Parse(json)) | |||
| { | |||
| JsonElement root = document.RootElement; | |||
| JsonElement cublasNode = root.GetProperty("libcublas"); | |||
| JsonElement versionNode = cublasNode.GetProperty("version"); | |||
| if (versionNode.ValueKind == JsonValueKind.Undefined) | |||
| { | |||
| return string.Empty; | |||
| } | |||
| return versionNode.GetString(); | |||
| } | |||
| } | |||
| catch (Exception) | |||
| { | |||
| return string.Empty; | |||
| } | |||
| } | |||
| #if NET6_0_OR_GREATER | |||
| private static string GetAvxLibraryPath(NativeLibraryConfig.AvxLevel avxLevel, string prefix, string suffix) | |||
| { | |||
| var avxStr = NativeLibraryConfig.AvxLevelToString(avxLevel); | |||
| if (!string.IsNullOrEmpty(avxStr)) | |||
| { | |||
| avxStr += "/"; | |||
| } | |||
| return $"{prefix}{avxStr}{libraryName}{suffix}"; | |||
| } | |||
| private static List<string> GetLibraryTryOrder(NativeLibraryConfig.Description configuration) | |||
| { | |||
| OSPlatform platform; | |||
| string prefix, suffix; | |||
| if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) | |||
| { | |||
| platform = OSPlatform.Windows; | |||
| prefix = "runtimes/win-x64/native/"; | |||
| suffix = ".dll"; | |||
| } | |||
| else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) | |||
| { | |||
| platform = OSPlatform.Linux; | |||
| prefix = "runtimes/linux-x64/native/"; | |||
| suffix = ".so"; | |||
| } | |||
| else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) | |||
| { | |||
| platform = OSPlatform.OSX; | |||
| suffix = ".dylib"; | |||
| if (System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported) | |||
| { | |||
| prefix = "runtimes/osx-arm64/native/"; | |||
| } | |||
| else | |||
| { | |||
| prefix = "runtimes/osx-x64/native/"; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| throw new RuntimeError($"Your system plarform is not supported, please open an issue in LLamaSharp."); | |||
| } | |||
| List<string> result = new(); | |||
| if (configuration.UseCuda && (platform == OSPlatform.Windows || platform == OSPlatform.Linux)) // no cuda on macos | |||
| { | |||
| int cudaVersion = GetCudaMajorVersion(); | |||
| // TODO: load cuda library with avx | |||
| if (cudaVersion == -1 && !configuration.AllowFallback) | |||
| { | |||
| // if check skipped, we just try to load cuda libraries one by one. | |||
| if (configuration.SkipCheck) | |||
| { | |||
| result.Add($"{prefix}cuda12/{libraryName}{suffix}"); | |||
| result.Add($"{prefix}cuda11/{libraryName}{suffix}"); | |||
| } | |||
| else | |||
| { | |||
| throw new RuntimeError("Configured to load a cuda library but no cuda detected on your device."); | |||
| } | |||
| } | |||
| else if (cudaVersion == 11) | |||
| { | |||
| result.Add($"{prefix}cuda11/{libraryName}{suffix}"); | |||
| } | |||
| else if (cudaVersion == 12) | |||
| { | |||
| result.Add($"{prefix}cuda12/{libraryName}{suffix}"); | |||
| } | |||
| else if (cudaVersion > 0) | |||
| { | |||
| throw new RuntimeError($"Cuda version {cudaVersion} hasn't been supported by LLamaSharp, please open an issue for it."); | |||
| } | |||
| // otherwise no cuda detected but allow fallback | |||
| } | |||
| // use cpu (or mac possibly with metal) | |||
| if (!configuration.AllowFallback && platform != OSPlatform.OSX) | |||
| { | |||
| result.Add(GetAvxLibraryPath(configuration.AvxLevel, prefix, suffix)); | |||
| } | |||
| else if(platform != OSPlatform.OSX) // in macos there's absolutely no avx | |||
| { | |||
| #if NET8_0_OR_GREATER | |||
| if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx512) | |||
| { | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx512, prefix, suffix))); | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix))); | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix))); | |||
| } | |||
| else | |||
| #endif | |||
| if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx2) | |||
| { | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix)); | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix)); | |||
| } | |||
| else if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx) | |||
| { | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix)); | |||
| } | |||
| result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.None, prefix, suffix)); | |||
| } | |||
| if(platform == OSPlatform.OSX) | |||
| { | |||
| result.Add($"{prefix}{libraryName}{suffix}"); | |||
| } | |||
| return result; | |||
| } | |||
| #endif | |||
| /// <summary> | |||
| /// Try to load libllama, using CPU feature detection to try and load a more specialised DLL if possible | |||
| /// </summary> | |||
| /// <returns>The library handle to unload later, or IntPtr.Zero if no library was loaded</returns> | |||
| private static IntPtr TryLoadLibrary() | |||
| { | |||
| #if NET6_0_OR_GREATER | |||
| var configuration = NativeLibraryConfig.GetInstance().Desc; | |||
| if (!string.IsNullOrEmpty(configuration.Path)) | |||
| { | |||
| // When loading the user specified library, there's no fallback. | |||
| var result = TryLoad(configuration.Path, true); | |||
| if (result is null || result == IntPtr.Zero) | |||
| { | |||
| throw new RuntimeError($"Failed to load the native library [{configuration.Path}] you specified."); | |||
| } | |||
| return result ?? IntPtr.Zero; | |||
| } | |||
| var libraryTryLoadOrder = GetLibraryTryOrder(configuration); | |||
| string[] possiblePathPrefix = new string[] { | |||
| System.AppDomain.CurrentDomain.BaseDirectory, | |||
| Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? "" | |||
| }; | |||
| var tryFindPath = (string filename) => | |||
| { | |||
| int i = 0; | |||
| while (!File.Exists(filename)) | |||
| { | |||
| if (i < possiblePathPrefix.Length) | |||
| { | |||
| filename = Path.Combine(possiblePathPrefix[i], filename); | |||
| i++; | |||
| } | |||
| else | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| return filename; | |||
| }; | |||
| foreach (var libraryPath in libraryTryLoadOrder) | |||
| { | |||
| var fullPath = tryFindPath(libraryPath); | |||
| var result = TryLoad(fullPath, true); | |||
| if(result is not null && result != IntPtr.Zero) | |||
| { | |||
| Console.ForegroundColor = ConsoleColor.Red; | |||
| Console.WriteLine($"[Native Library] {fullPath} is loaded."); | |||
| Console.ResetColor(); | |||
| return result ?? IntPtr.Zero; | |||
| } | |||
| else | |||
| { | |||
| Console.WriteLine($"Tried to load {fullPath}"); | |||
| } | |||
| } | |||
| if (!configuration.AllowFallback) | |||
| { | |||
| throw new RuntimeError("Failed to load the library that match your rule, please" + | |||
| " 1) check your rule." + | |||
| " 2) try to allow fallback." + | |||
| " 3) or open an issue if it's expected to be successful."); | |||
| } | |||
| #endif | |||
| return IntPtr.Zero; | |||
| #if NET6_0_OR_GREATER | |||
| // Try to load a DLL from the path if supported. Returns null if nothing is loaded. | |||
| static IntPtr? TryLoad(string path, bool supported = true) | |||
| { | |||
| if (!supported) | |||
| return null; | |||
| if (NativeLibrary.TryLoad(path, out var handle)) | |||
| return handle; | |||
| return null; | |||
| } | |||
| #endif | |||
| } | |||
| private const string libraryName = "libllama"; | |||
| private const string cudaVersionFile = "version.json"; | |||
| /// <summary> | |||
| /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded. | |||
| /// </summary> | |||
| @@ -0,0 +1,201 @@ | |||
| using System; | |||
| namespace LLama.Native | |||
| { | |||
| #if NET6_0_OR_GREATER | |||
| /// <summary> | |||
| /// A class about configurations when loading native libraries. | |||
| /// Note that it could be configured only once before any call to llama model apis. | |||
| /// </summary> | |||
| public class NativeLibraryConfig | |||
| { | |||
| private static NativeLibraryConfig? instance; | |||
| private static readonly object lockObject = new object(); | |||
| public static NativeLibraryConfig Default | |||
| { | |||
| get | |||
| { | |||
| return GetInstance(); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Whether there's already a config for native library. | |||
| /// </summary> | |||
| public static bool LibraryHasLoaded { get; internal set; } = false; | |||
| private string _libraryPath; | |||
| private bool _useCuda; | |||
| private AvxLevel _avxLevel; | |||
| private bool _allowFallback; | |||
| private bool _skipCheck; | |||
| private bool _logging; | |||
| internal static NativeLibraryConfig GetInstance() | |||
| { | |||
| if (instance is null) | |||
| { | |||
| lock (lockObject) | |||
| { | |||
| if (instance is null) | |||
| { | |||
| instance = new NativeLibraryConfig(); | |||
| } | |||
| } | |||
| } | |||
| return instance; | |||
| } | |||
| /// <summary> | |||
| /// Load a specified native library as backend for LLamaSharp. | |||
| /// When this method is called, all the other configurations will be ignored. | |||
| /// </summary> | |||
| /// <param name="libraryPath"></param> | |||
| /// <exception cref="InvalidOperationException"></exception> | |||
| public NativeLibraryConfig WithLibrary(string libraryPath) | |||
| { | |||
| if (LibraryHasLoaded) | |||
| { | |||
| throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis."); | |||
| } | |||
| _libraryPath = libraryPath; | |||
| return this; | |||
| } | |||
| /// <summary> | |||
| /// Configure whether to use cuda backend if possible. | |||
| /// </summary> | |||
| /// <param name="enable"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="InvalidOperationException"></exception> | |||
| public NativeLibraryConfig WithCuda(bool enable = true) | |||
| { | |||
| if (LibraryHasLoaded) | |||
| { | |||
| throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis."); | |||
| } | |||
| _useCuda = enable; | |||
| return this; | |||
| } | |||
| /// <summary> | |||
| /// Configure the prefferred avx support level of the backend. | |||
| /// </summary> | |||
| /// <param name="level"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="InvalidOperationException"></exception> | |||
| public NativeLibraryConfig WithAvx(AvxLevel level) | |||
| { | |||
| if (LibraryHasLoaded) | |||
| { | |||
| throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis."); | |||
| } | |||
| _avxLevel = level; | |||
| return this; | |||
| } | |||
| /// <summary> | |||
| /// Configure whether to allow fallback when there's not match for preffered settings. | |||
| /// </summary> | |||
| /// <param name="enable"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="InvalidOperationException"></exception> | |||
| public NativeLibraryConfig WithAutoFallback(bool enable = true) | |||
| { | |||
| if (LibraryHasLoaded) | |||
| { | |||
| throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis."); | |||
| } | |||
| _allowFallback = enable; | |||
| return this; | |||
| } | |||
/// <summary>
/// Whether to skip the validity check when you don't allow fallback. This
/// may be useful under some complex conditions — for example, when you are
/// sure your cublas is configured correctly but LLamaSharp mistakenly
/// treats it as invalid.
/// </summary>
/// <param name="enable">True to skip the check</param>
/// <returns>This config instance, for fluent chaining</returns>
/// <exception cref="InvalidOperationException">Thrown if the native library has already been loaded</exception>
public NativeLibraryConfig SkipCheck(bool enable = true)
{
    if (LibraryHasLoaded)
        throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");

    _skipCheck = enable;
    return this;
}
/// <summary>
/// Whether to print logs to the console while the native library is being
/// loaded with this configuration.
/// </summary>
/// <param name="enable">True to enable log output</param>
/// <returns>This config instance, for fluent chaining</returns>
/// <exception cref="InvalidOperationException">Thrown if the native library has already been loaded</exception>
public NativeLibraryConfig WithLogs(bool enable = true)
{
    if (LibraryHasLoaded)
        throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");

    _logging = enable;
    return this;
}
/// <summary>
/// Validate the default configuration and package it into an immutable
/// <see cref="Description"/> for the library loader.
/// </summary>
/// <exception cref="ArgumentException">Thrown when the check is skipped while fallback is allowed — an invalid combination</exception>
internal static Description CheckAndGatherDescription()
{
    var config = Default;
    if (config._skipCheck && config._allowFallback)
    {
        throw new ArgumentException("Cannot skip the check when fallback is allowed.");
    }

    return new Description(config._libraryPath, config._useCuda, config._avxLevel, config._allowFallback, config._skipCheck, config._logging);
}
/// <summary>
/// Map an <see cref="AvxLevel"/> to the suffix used in native library file names.
/// </summary>
/// <param name="level">The AVX level to convert</param>
/// <returns>The file-name suffix, or an empty string for <see cref="AvxLevel.None"/></returns>
/// <exception cref="ArgumentException">Thrown for an unrecognized level</exception>
internal static string AvxLevelToString(AvxLevel level)
{
    return level switch
    {
        AvxLevel.None => string.Empty,
        AvxLevel.Avx => "avx",
        AvxLevel.Avx2 => "avx2",
#if NET8_0_OR_GREATER
        // Fix: the trailing comma was missing here, so this switch expression
        // failed to compile whenever NET8_0_OR_GREATER was defined.
        AvxLevel.Avx512 => "avx512",
#endif
        _ => throw new ArgumentException($"Cannot recognize Avx level {level}")
    };
}
/// <summary>
/// Private constructor — callers go through the static default instance.
/// </summary>
private NativeLibraryConfig()
{
    // Default behaviour: no explicit path, try CUDA, prefer AVX2, allow
    // fallback, run the validity check, and stay quiet.
    _libraryPath = string.Empty;
    _avxLevel = AvxLevel.Avx2;
    _useCuda = true;
    _allowFallback = true;
    _skipCheck = false;
    _logging = false;
}
/// <summary>
/// Avx support configuration
/// </summary>
public enum AvxLevel
{
    /// <summary>No AVX instructions required</summary>
    None = 0,

    /// <summary>Require AVX support</summary>
    Avx = 1,

    /// <summary>Require AVX2 support</summary>
    Avx2 = 2,

#if NET8_0_OR_GREATER
    /// <summary>Require AVX-512 support (only available on net8.0+ builds)</summary>
    Avx512 = 3,
#endif
}
/// <summary>
/// An immutable snapshot of the native-library loading configuration:
/// explicit path (empty = match by rules), CUDA preference, AVX level,
/// fallback policy, check skipping, and log output.
/// </summary>
internal record Description(string Path = "", bool UseCuda = true, AvxLevel AvxLevel = AvxLevel.Avx2,
    bool AllowFallback = true, bool SkipCheck = false, bool Logging = false);
| } | |||
| #endif | |||
| } | |||
| @@ -1,113 +0,0 @@ | |||
| using System; | |||
| namespace LLama | |||
| { | |||
| #if NET6_0_OR_GREATER | |||
/// <summary>
/// A class about configurations when loading native libraries.
/// Note that it could be configured only once before any call to llama model apis.
/// </summary>
public class NativeLibraryConfig
{
    private static NativeLibraryConfig? instance;
    private static readonly object lockObject = new object();

    /// <summary>
    /// Whether there's already a config for native library.
    /// (The historical misspelling is kept because this property is public API.)
    /// </summary>
    public bool Initialied { get; private set; }

    /// <summary>
    /// The description used to select the native library to load.
    /// </summary>
    internal Description Desc { get; private set; }

    /// <summary>
    /// Get the process-wide singleton, creating it lazily and thread-safely
    /// (double-checked locking on <see cref="lockObject"/>).
    /// </summary>
    internal static NativeLibraryConfig GetInstance()
    {
        if (instance is null)
        {
            lock (lockObject)
            {
                instance ??= new NativeLibraryConfig();
            }
        }
        return instance;
    }

    /// <summary>
    /// Load a specified native library as backend for LLamaSharp
    /// </summary>
    /// <param name="libraryPath">Path of the native library file to load</param>
    /// <exception cref="InvalidOperationException">Thrown if a configuration was already applied</exception>
    public static void WithLibrary(string libraryPath)
    {
        var config = GetInstance();
        if (config.Initialied)
        {
            throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
        }
        config.Desc = new Description(libraryPath);
        // Fix: mark the config as applied. Previously this flag was never set
        // anywhere, so the configure-only-once guard above could never fire.
        config.Initialied = true;
    }

    /// <summary>
    /// Add rules to match a suitable library from installed LLamaSharp backend.
    /// </summary>
    /// <param name="useCuda">Whether to prefer a CUDA build of the backend</param>
    /// <param name="avxLevel">The preferred AVX support level</param>
    /// <param name="allowFallback">Whether to allow fall-back when your hardware doesn't support your configuration.</param>
    /// <param name="skipCheck">Whether to skip the validity check; only valid when <paramref name="allowFallback"/> is false.
    /// It's especially useful when your cuda library is not in the default path. </param>
    /// <exception cref="ArgumentException">Thrown when <paramref name="skipCheck"/> is combined with <paramref name="allowFallback"/></exception>
    /// <exception cref="InvalidOperationException">Thrown if a configuration was already applied</exception>
    public static void WithMatchRule(bool useCuda = true, AvxLevel avxLevel = AvxLevel.Avx2, bool allowFallback = true, bool skipCheck = false)
    {
        if (allowFallback && skipCheck)
        {
            throw new ArgumentException("Cannot skip the check when fallback is allowed.");
        }
        var config = GetInstance();
        if (config.Initialied)
        {
            throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
        }
        config.Desc = new Description(UseCuda: useCuda, AvxLevel: avxLevel, AllowFallback: allowFallback, SkipCheck: skipCheck);
        // Fix: mark the config as applied (same dead-guard bug as WithLibrary).
        config.Initialied = true;
    }

    /// <summary>
    /// Map an AVX level to the suffix used in native library file names.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown for an unrecognized level</exception>
    internal static string AvxLevelToString(AvxLevel level)
    {
        return level switch
        {
            AvxLevel.None => string.Empty,
            AvxLevel.Avx => "avx",
            AvxLevel.Avx2 => "avx2",
#if NET8_0_OR_GREATER
            // Fix: trailing comma was missing, which broke compilation on net8.0+.
            AvxLevel.Avx512 => "avx512",
#endif
            _ => throw new ArgumentException($"Cannot recognize Avx level {level}")
        };
    }

    private NativeLibraryConfig()
    {
        // Start from the default match rule until the user configures otherwise.
        Desc = new Description();
    }

    /// <summary>
    /// Avx support configuration
    /// </summary>
    public enum AvxLevel
    {
        /// <summary>No AVX instructions required</summary>
        None = 0,

        /// <summary>Require AVX support</summary>
        Avx = 1,

        /// <summary>Require AVX2 support</summary>
        Avx2 = 2,

#if NET8_0_OR_GREATER
        /// <summary>Require AVX-512 support (only available on net8.0+ builds)</summary>
        Avx512 = 3,
#endif
    }

    /// <summary>
    /// An immutable snapshot of the native-library loading configuration.
    /// </summary>
    internal record Description(string Path = "", bool UseCuda = true, AvxLevel AvxLevel = AvxLevel.Avx2, bool AllowFallback = true, bool SkipCheck = false);
}
| #endif | |||
| } | |||
| @@ -6,169 +6,170 @@ using System.Text; | |||
| using LLama.Extensions; | |||
| using LLama.Native; | |||
| namespace LLama; | |||
| /// <summary> | |||
| /// Decodes a stream of tokens into a stream of characters | |||
| /// </summary> | |||
| public sealed class StreamingTokenDecoder | |||
| namespace LLama | |||
| { | |||
| private readonly SafeLlamaModelHandle _weights; | |||
| private readonly Decoder _decoder; | |||
| private readonly List<char> _characters = new(); | |||
| /// <summary> | |||
| /// The number of decoded characters waiting to be read | |||
| /// </summary> | |||
| public int AvailableCharacters => _characters.Count; | |||
| #region constructors | |||
/// <summary>
/// Create a new decoder, reading vocabulary data through the model's native handle.
/// </summary>
/// <param name="encoding">Text encoding used to turn token bytes into characters</param>
/// <param name="weights">Model weights; only the native handle is used</param>
public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
    : this(encoding, weights.NativeHandle)
{
}
/// <summary>
/// Create a new decoder, taking both the text encoding and the model
/// handle from an existing context.
/// </summary>
/// <param name="context">Context to retrieve encoding and model weights from</param>
public StreamingTokenDecoder(LLamaContext context)
    : this(context.Encoding, context.NativeHandle)
{
}
| /// <summary> | |||
| /// Create a new decoder | |||
| /// Decodes a stream of tokens into a stream of characters | |||
| /// </summary> | |||
| /// <param name="encoding">Text encoding to use</param> | |||
| /// <param name="context">Context to retrieve model weights from</param> | |||
| public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context) | |||
| : this(encoding, context.ModelHandle) | |||
| public sealed class StreamingTokenDecoder | |||
| { | |||
| } | |||
/// <summary>
/// Create a new decoder. This is the designated constructor the other
/// overloads delegate to; it captures the model handle and a stateful
/// <see cref="Decoder"/> from the encoding.
/// </summary>
/// <param name="encoding">Text encoding to use</param>
/// <param name="weights">Model weights to use</param>
public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights)
{
    _weights = weights;
    _decoder = encoding.GetDecoder();
}
| #endregion | |||
| /// <summary> | |||
| /// Add a single token to the decoder | |||
| /// </summary> | |||
| /// <param name="token"></param> | |||
| public void Add(int token) | |||
| { | |||
| var charsArr = ArrayPool<char>.Shared.Rent(16); | |||
| var bytesArr = ArrayPool<byte>.Shared.Rent(16); | |||
| try | |||
// Model handle used to turn token ids back into raw bytes.
private readonly SafeLlamaModelHandle _weights;
// Stateful decoder: Convert is called with flush=false, so it carries
// partial multi-byte sequences between Add calls.
private readonly Decoder _decoder;
// Decoded output buffered until the caller drains it with Read().
private readonly List<char> _characters = new();

/// <summary>
/// The number of decoded characters waiting to be read
/// </summary>
public int AvailableCharacters => _characters.Count;
| #region constructors | |||
| /// <summary> | |||
| /// Create a new decoder | |||
| /// </summary> | |||
| /// <param name="encoding">Text encoding to use</param> | |||
| /// <param name="weights">Model weights</param> | |||
| public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights) | |||
| : this(encoding, weights.NativeHandle) | |||
| { | |||
| // Convert this token into bytes | |||
| var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length; | |||
| // Convert those bytes into characters | |||
| var bytesOffset = 0; | |||
| var completed = false; | |||
| while (!completed) | |||
| { | |||
| // Decode some of the bytes into the temp char buffer. Keep doing this | |||
| // until all bytes have been consumed | |||
| _decoder.Convert( | |||
| bytesArr, bytesOffset, bytesAvailable, | |||
| charsArr, 0, charsArr.Length, | |||
| false, | |||
| out var bytesUsed, out var charsUsed, out completed | |||
| ); | |||
| bytesOffset += bytesUsed; | |||
| bytesAvailable -= bytesUsed; | |||
| // Add the decoded characters to the output buffer | |||
| _characters.AddSpan(charsArr.AsSpan(0, charsUsed)); | |||
| } | |||
| } | |||
| finally | |||
| /// <summary> | |||
| /// Create a new decoder | |||
| /// </summary> | |||
| /// <param name="context">Context to retrieve encoding and model weights from</param> | |||
| public StreamingTokenDecoder(LLamaContext context) | |||
| : this(context.Encoding, context.NativeHandle) | |||
| { | |||
| ArrayPool<char>.Shared.Return(charsArr); | |||
| ArrayPool<byte>.Shared.Return(bytesArr); | |||
| } | |||
| return; | |||
| /// <summary> | |||
| /// Create a new decoder | |||
| /// </summary> | |||
| /// <param name="encoding">Text encoding to use</param> | |||
| /// <param name="context">Context to retrieve model weights from</param> | |||
| public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context) | |||
| : this(encoding, context.ModelHandle) | |||
| { | |||
| } | |||
| // Converts a single token into bytes, using the `bytes` array as temporary storage. | |||
| // If the `bytes` array is too small it will get a larger one from the ArrayPool. | |||
| static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model) | |||
| /// <summary> | |||
| /// Create a new decoder | |||
| /// </summary> | |||
| /// <param name="encoding">Text encoding to use</param> | |||
| /// <param name="weights">Model weights to use</param> | |||
| public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights) | |||
| { | |||
| // Try to get bytes | |||
| var l = model.TokenToSpan(token, bytes); | |||
| _weights = weights; | |||
| _decoder = encoding.GetDecoder(); | |||
| } | |||
| #endregion | |||
| // Negative length indicates that the output was too small. Expand it to twice that size and try again. | |||
| if (l < 0) | |||
| /// <summary> | |||
| /// Add a single token to the decoder | |||
| /// </summary> | |||
| /// <param name="token"></param> | |||
| public void Add(int token) | |||
| { | |||
| var charsArr = ArrayPool<char>.Shared.Rent(16); | |||
| var bytesArr = ArrayPool<byte>.Shared.Rent(16); | |||
| try | |||
| { | |||
| // Return the old array to the pool and get a new one | |||
| ArrayPool<byte>.Shared.Return(bytes); | |||
| bytes = ArrayPool<byte>.Shared.Rent(-l * 2); | |||
| // Get bytes, this time it can't fail | |||
| l = model.TokenToSpan(token, bytes); | |||
| // Convert this token into bytes | |||
| var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length; | |||
| // Convert those bytes into characters | |||
| var bytesOffset = 0; | |||
| var completed = false; | |||
| while (!completed) | |||
| { | |||
| // Decode some of the bytes into the temp char buffer. Keep doing this | |||
| // until all bytes have been consumed | |||
| _decoder.Convert( | |||
| bytesArr, bytesOffset, bytesAvailable, | |||
| charsArr, 0, charsArr.Length, | |||
| false, | |||
| out var bytesUsed, out var charsUsed, out completed | |||
| ); | |||
| bytesOffset += bytesUsed; | |||
| bytesAvailable -= bytesUsed; | |||
| // Add the decoded characters to the output buffer | |||
| _characters.AddSpan(charsArr.AsSpan(0, charsUsed)); | |||
| } | |||
| } | |||
| finally | |||
| { | |||
| ArrayPool<char>.Shared.Return(charsArr); | |||
| ArrayPool<byte>.Shared.Return(bytesArr); | |||
| } | |||
| Debug.Assert(l >= 0); | |||
| return new Span<byte>(bytes, 0, l); | |||
| return; | |||
// Converts a single token into bytes, using the `bytes` array as temporary storage.
// If the `bytes` array is too small it will get a larger one from the ArrayPool,
// returning the old one to the pool; the caller's reference is updated via `ref`
// so the caller can return whichever buffer it ends up holding.
static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
{
    // Try to get bytes into the current buffer.
    var l = model.TokenToSpan(token, bytes);

    // A negative result indicates the buffer was too small (magnitude is the
    // required size — TODO confirm against TokenToSpan's contract). Rent twice
    // that size for headroom and retry.
    if (l < 0)
    {
        // Return the old array to the pool and get a new one.
        ArrayPool<byte>.Shared.Return(bytes);
        bytes = ArrayPool<byte>.Shared.Rent(-l * 2);

        // Get bytes, this time it can't fail — the buffer is large enough.
        l = model.TokenToSpan(token, bytes);
    }

    Debug.Assert(l >= 0);
    return new Span<byte>(bytes, 0, l);
}
| } | |||
| } | |||
/// <summary>
/// Add all tokens in the given enumerable
/// </summary>
/// <param name="tokens">Token ids, decoded in iteration order</param>
public void AddRange(IEnumerable<int> tokens)
{
    foreach (var item in tokens)
        Add(item);
}
/// <summary>
/// Feed every token from the given sequence into this decoder, in order.
/// </summary>
/// <param name="tokens">Token ids to decode</param>
public void AddRange(IEnumerable<int> tokens)
{
    foreach (var token in tokens)
    {
        Add(token);
    }
}
/// <summary>
/// Read all decoded characters and clear the buffer
/// </summary>
/// <param name="dest">Destination list; the characters are appended to it</param>
public void Read(List<char> dest)
{
    dest.AddRange(_characters);
    _characters.Clear();
}
/// <summary>
/// Append every buffered decoded character to <paramref name="dest"/>,
/// then empty the internal buffer.
/// </summary>
/// <param name="dest">Destination list for the decoded characters</param>
public void Read(List<char> dest)
{
    foreach (var character in _characters)
        dest.Add(character);
    _characters.Clear();
}
| /// <summary> | |||
| /// Read all decoded characters as a string and clear the buffer | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public string Read() | |||
| { | |||
| if (_characters.Count == 0) | |||
| return ""; | |||
/// <summary>
/// Drain the decoder: return everything decoded so far as a string and
/// clear the internal buffer.
/// </summary>
/// <returns>The decoded text, or an empty string if nothing is buffered</returns>
public string Read()
{
    if (_characters.Count == 0)
        return "";

    var result = new string(_characters.ToArray());
    _characters.Clear();
    return result;
}
| var str = string.Join("", _characters); | |||
| _characters.Clear(); | |||
| return str; | |||
| } | |||
| /// <summary> | |||
| /// Set the decoder back to its initial state | |||
| /// </summary> | |||
| public void Reset() | |||
| { | |||
| _decoder.Reset(); | |||
| _characters.Clear(); | |||
/// <summary>
/// Set the decoder back to its initial state
/// </summary>
public void Reset()
{
    // The two operations are independent: drop buffered output and discard
    // any partial byte sequence held inside the encoding's Decoder.
    _characters.Clear();
    _decoder.Reset();
}
| } | |||
| } | |||
| } | |||