
feat: optimize apis for cuda feature detection.

tags/v0.8.0
Yaohui Liu 2 years ago
commit cb5fb210b1
7 changed files with 755 additions and 618 deletions
  1. LLama.Examples/Program.cs (+1 -0)
  2. LLama/AntipromptProcessor.cs (+52 -51)
  3. LLama/Native/NativeApi.Load.cs (+356 -0)
  4. LLama/Native/NativeApi.cs (+0 -310)
  5. LLama/Native/NativeLibraryConfig.cs (+201 -0)
  6. LLama/NativeLibraryConfig.cs (+0 -113)
  7. LLama/StreamingTokenDecoder.cs (+145 -144)

+1 -0  LLama.Examples/Program.cs

@@ -7,6 +7,7 @@ Console.WriteLine(" __ __ ____ _

Console.WriteLine("======================================================================================================");

NativeLibraryConfig.Default.WithCuda().WithLogs();

NativeApi.llama_empty_call();
Console.WriteLine();


+52 -51  LLama/AntipromptProcessor.cs

@@ -1,66 +1,67 @@
(Old and new versions of this file differ only in indentation: the class is now wrapped in a block-scoped namespace. The resulting file:)

using System;
using System.Collections.Generic;

namespace LLama
{
    internal sealed class AntipromptProcessor
    {
        private int _longestAntiprompt;
        private readonly List<string> _antiprompts = new();

        private string? _string;

        public AntipromptProcessor(IEnumerable<string>? antiprompts = null)
        {
            if (antiprompts != null)
                SetAntiprompts(antiprompts);
        }

        /// <summary>
        /// Add an antiprompt to the collection
        /// </summary>
        /// <param name="antiprompt"></param>
        public void AddAntiprompt(string antiprompt)
        {
            _antiprompts.Add(antiprompt);
            _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);
        }

        /// <summary>
        /// Overwrite all current antiprompts with a new set
        /// </summary>
        /// <param name="antiprompts"></param>
        public void SetAntiprompts(IEnumerable<string> antiprompts)
        {
            _antiprompts.Clear();
            _antiprompts.AddRange(antiprompts);

            _longestAntiprompt = 0;
            foreach (var antiprompt in _antiprompts)
                _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);
        }

        /// <summary>
        /// Add some text and check if the buffer now ends with any antiprompt
        /// </summary>
        /// <param name="text"></param>
        /// <returns>true if the text buffer ends with any antiprompt</returns>
        public bool Add(string text)
        {
            _string += text;

            // When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length).
            // This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode
            // even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances!
            var maxLength = Math.Max(32, _longestAntiprompt * 4);
            var trimLength = Math.Max(16, _longestAntiprompt * 2);
            if (_string.Length > maxLength)
                _string = _string.Substring(_string.Length - trimLength);

            foreach (var antiprompt in _antiprompts)
                if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture))
                    return true;

            return false;
        }
    }
}
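For context, a minimal usage sketch of the class above. Here generatedChunks is a hypothetical stand-in for the chunk-by-chunk text stream coming out of the model; since the class is internal, this is how LLamaSharp itself would drive it rather than application code:

    var processor = new AntipromptProcessor(new[] { "User:" });
    foreach (var chunk in generatedChunks)
    {
        // Add() buffers the text and returns true as soon as the
        // buffer ends with any of the configured antiprompts.
        if (processor.Add(chunk))
            break; // stop generation here
    }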

+356 -0  LLama/Native/NativeApi.Load.cs

@@ -0,0 +1,356 @@
using LLama.Exceptions;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;

namespace LLama.Native
{
public partial class NativeApi
{
static NativeApi()
{
// Try to load a preferred library, based on CPU feature detection
TryLoadLibrary();

try
{
llama_empty_call();
}
catch (DllNotFoundException)
{
throw new RuntimeError("The native library cannot be found. It could be one of the following reasons: \n" +
"1. No LLamaSharp backend was installed. Please search LLamaSharp.Backend and install one of them. \n" +
"2. You are using a device with only CPU but installed cuda backend. Please install cpu backend instead. \n" +
"3. The backend is not compatible with your system cuda environment. Please check and fix it. If the environment is " +
"expected not to be changed, then consider build llama.cpp from source or submit an issue to LLamaSharp.\n" +
"4. One of the dependency of the native library is missed.\n");
}
llama_backend_init(false);
}

private static void Log(string message, LogLevel level)
{
if (!enableLogging) return;
Debug.Assert(level is LogLevel.Information or LogLevel.Error or LogLevel.Warning);
ConsoleColor color;
string levelPrefix;
if (level == LogLevel.Information)
{
color = ConsoleColor.Green;
levelPrefix = "[Info]";
}
else if (level == LogLevel.Error)
{
color = ConsoleColor.Red;
levelPrefix = "[Error]";
}
else
{
color = ConsoleColor.Yellow;
levelPrefix = "[Error]";
}
Console.ForegroundColor = color;
Console.WriteLine($"{loggingPrefix} {levelPrefix} {message}");
Console.ResetColor();
}

private static int GetCudaMajorVersion()
{
string? cudaPath;
string version = "";
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH");
if (cudaPath is null)
{
return -1;
}
version = GetCudaVersionFromPath(cudaPath);
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
// Try the standard installation location first
cudaPath = "/usr/local/cuda";
version = GetCudaVersionFromPath(cudaPath);
if (string.IsNullOrEmpty(version))
{
cudaPath = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH");
if (cudaPath is null)
{
return -1;
}
foreach (var path in cudaPath.Split(':'))
{
version = GetCudaVersionFromPath(Path.Combine(path, ".."));
if (!string.IsNullOrEmpty(version))
{
break; // stop as soon as a version has been found
}
}
}
}

if (string.IsNullOrEmpty(version))
{
return -1;
}
else
{
version = version.Split('.')[0];
bool success = int.TryParse(version, out var majorVersion);
if (success)
{
return majorVersion;
}
else
{
return -1;
}
}
}

private static string GetCudaVersionFromPath(string cudaPath)
{
try
{
string json = File.ReadAllText(Path.Combine(cudaPath, cudaVersionFile));
using (JsonDocument document = JsonDocument.Parse(json))
{
JsonElement root = document.RootElement;
JsonElement cublasNode = root.GetProperty("libcublas");
JsonElement versionNode = cublasNode.GetProperty("version");
if (versionNode.ValueKind == JsonValueKind.Undefined)
{
return string.Empty;
}
return versionNode.GetString() ?? string.Empty;
}
}
catch (Exception)
{
return string.Empty;
}
}

#if NET6_0_OR_GREATER
private static string GetAvxLibraryPath(NativeLibraryConfig.AvxLevel avxLevel, string prefix, string suffix)
{
var avxStr = NativeLibraryConfig.AvxLevelToString(avxLevel);
if (!string.IsNullOrEmpty(avxStr))
{
avxStr += "/";
}
return $"{prefix}{avxStr}{libraryName}{suffix}";
}

private static List<string> GetLibraryTryOrder(NativeLibraryConfig.Description configuration)
{
OSPlatform platform;
string prefix, suffix;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
platform = OSPlatform.Windows;
prefix = "runtimes/win-x64/native/";
suffix = ".dll";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
platform = OSPlatform.Linux;
prefix = "runtimes/linux-x64/native/";
suffix = ".so";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
platform = OSPlatform.OSX;
suffix = ".dylib";
if (System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported)
{
prefix = "runtimes/osx-arm64/native/";
}
else
{
prefix = "runtimes/osx-x64/native/";
}
}
else
{
throw new RuntimeError($"Your system plarform is not supported, please open an issue in LLamaSharp.");
}
Log($"Detected OS Platform: {platform}", LogLevel.Information);

List<string> result = new();
if (configuration.UseCuda && (platform == OSPlatform.Windows || platform == OSPlatform.Linux)) // no cuda on macos
{
int cudaVersion = GetCudaMajorVersion();

// TODO: load cuda library with avx
if (cudaVersion == -1 && !configuration.AllowFallback)
{
// if check skipped, we just try to load cuda libraries one by one.
if (configuration.SkipCheck)
{
result.Add($"{prefix}cuda12/{libraryName}{suffix}");
result.Add($"{prefix}cuda11/{libraryName}{suffix}");
}
else
{
throw new RuntimeError("Configured to load a cuda library but no cuda detected on your device.");
}
}
else if (cudaVersion == 11)
{
Log($"Detected cuda major version {cudaVersion}.", LogLevel.Information);
result.Add($"{prefix}cuda11/{libraryName}{suffix}");
}
else if (cudaVersion == 12)
{
Log($"Detected cuda major version {cudaVersion}.", LogLevel.Information);
result.Add($"{prefix}cuda12/{libraryName}{suffix}");
}
else if (cudaVersion > 0)
{
throw new RuntimeError($"Cuda version {cudaVersion} hasn't been supported by LLamaSharp, please open an issue for it.");
}
// otherwise no cuda detected but allow fallback
}

// use cpu (or mac possibly with metal)
if (!configuration.AllowFallback && platform != OSPlatform.OSX)
{
result.Add(GetAvxLibraryPath(configuration.AvxLevel, prefix, suffix));
}
else if (platform != OSPlatform.OSX) // in macos there's absolutely no avx
{
#if NET8_0_OR_GREATER
if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx512)
{
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx512, prefix, suffix));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix));
}
else
#endif
if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx2)
{
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix));
}
else if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx)
{
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix));
}
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.None, prefix, suffix));
}

if (platform == OSPlatform.OSX)
{
result.Add($"{prefix}{libraryName}{suffix}");
}

return result;
}
#endif

/// <summary>
/// Try to load libllama, using CPU feature detection to try and load a more specialised DLL if possible
/// </summary>
/// <returns>The library handle to unload later, or IntPtr.Zero if no library was loaded</returns>
private static IntPtr TryLoadLibrary()
{
#if NET6_0_OR_GREATER
var configuration = NativeLibraryConfig.CheckAndGatherDescription();
enableLogging = configuration.Logging;
// Set the flag now, so the configuration is locked before anything else can touch the library.
NativeLibraryConfig.LibraryHasLoaded = true;

if (!string.IsNullOrEmpty(configuration.Path))
{
// When loading the user specified library, there's no fallback.
var success = NativeLibrary.TryLoad(configuration.Path, out var result);
if (!success)
{
throw new RuntimeError($"Failed to load the native library [{configuration.Path}] you specified.");
}
Log($"Successfully loaded the library [{configuration.Path}] specified by user", LogLevel.Information);
return result;
}

var libraryTryLoadOrder = GetLibraryTryOrder(configuration);

string[] possiblePathPrefix = new string[] {
System.AppDomain.CurrentDomain.BaseDirectory,
Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
};

var tryFindPath = (string filename) =>
{
int i = 0;
while (!File.Exists(filename))
{
if (i < possiblePathPrefix.Length)
{
filename = Path.Combine(possiblePathPrefix[i], filename);
i++;
}
else
{
break;
}
}
return filename;
};

foreach (var libraryPath in libraryTryLoadOrder)
{
var fullPath = tryFindPath(libraryPath);
var result = TryLoad(fullPath, true);
if (result is not null && result != IntPtr.Zero)
{
Log($"{fullPath} is selected and loaded successfully.", LogLevel.Information);
return result ?? IntPtr.Zero;
}
else
{
Log($"Tried to load {fullPath} but failed.", LogLevel.Information);
}
}

if (!configuration.AllowFallback)
{
throw new RuntimeError("Failed to load the library that match your rule, please" +
" 1) check your rule." +
" 2) try to allow fallback." +
" 3) or open an issue if it's expected to be successful.");
}
#endif

Log($"No library was loaded before calling native apis. " +
$"This is not an error under netstandard2.0 but needs attention with net6 or higher.", LogLevel.Warning);
return IntPtr.Zero;

#if NET6_0_OR_GREATER
// Try to load a DLL from the path if supported. Returns null if nothing is loaded.
static IntPtr? TryLoad(string path, bool supported = true)
{
if (!supported)
return null;

if (NativeLibrary.TryLoad(path, out var handle))
return handle;

return null;
}
#endif
}

private const string libraryName = "libllama";
private const string cudaVersionFile = "version.json";
private const string loggingPrefix = "[LLamaSharp Native]";
private static bool enableLogging = false;
}
}
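For reference, GetCudaVersionFromPath above reads only the libcublas.version field from the version.json that a CUDA toolkit installation is assumed to ship in its root directory. A minimal sketch of that lookup against a trimmed, assumed file shape:

    using System.Text.Json;

    // Assumed shape of <CUDA_PATH>/version.json (heavily trimmed)
    const string json = "{ \"cuda\": { \"version\": \"12.2.0\" }, \"libcublas\": { \"version\": \"12.2.1\" } }";

    using var doc = JsonDocument.Parse(json);
    var version = doc.RootElement.GetProperty("libcublas").GetProperty("version").GetString();
    // version is "12.2.1"; GetCudaMajorVersion keeps only the major part ("12").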

+0 -310  LLama/Native/NativeApi.cs

@@ -1,11 +1,7 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;
using LLama.Exceptions;

#pragma warning disable IDE1006 // Naming Styles

@@ -25,312 +21,6 @@ namespace LLama.Native
/// </summary>
public unsafe partial class NativeApi
{
static NativeApi()
{
// Try to load a preferred library, based on CPU feature detection
TryLoadLibrary();

try
{
llama_empty_call();
}
catch (DllNotFoundException)
{
throw new RuntimeError("The native library cannot be found. It could be one of the following reasons: \n" +
"1. No LLamaSharp backend was installed. Please search LLamaSharp.Backend and install one of them. \n" +
"2. You are using a device with only CPU but installed cuda backend. Please install cpu backend instead. \n" +
"3. The backend is not compatible with your system cuda environment. Please check and fix it. If the environment is " +
"expected not to be changed, then consider build llama.cpp from source or submit an issue to LLamaSharp.\n" +
"4. One of the dependency of the native library is missed.\n");
}
llama_backend_init(false);
}

private static int GetCudaMajorVersion()
{
string? cudaPath;
string version = "";
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH");
if(cudaPath is null)
{
return -1;
}
version = GetCudaVersionFromPath(cudaPath);
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
// Try the default first
cudaPath = "/usr/local/bin/cuda";
version = GetCudaVersionFromPath(cudaPath);
if (string.IsNullOrEmpty(version))
{
cudaPath = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH");
if(cudaPath is null)
{
return -1;
}
foreach(var path in cudaPath.Split(':'))
{
version = GetCudaVersionFromPath(Path.Combine(path, ".."));
if (string.IsNullOrEmpty(version))
{
break;
}
}
}
}

if (string.IsNullOrEmpty(version))
{
return -1;
}
else
{
version = version.Split('.')[0];
bool success = int.TryParse(version, out var majorVersion);
if (success)
{
return majorVersion;
}
else
{
return -1;
}
}
}

private static string GetCudaVersionFromPath(string cudaPath)
{
try
{
string json = File.ReadAllText(Path.Combine(cudaPath, cudaVersionFile));
using (JsonDocument document = JsonDocument.Parse(json))
{
JsonElement root = document.RootElement;
JsonElement cublasNode = root.GetProperty("libcublas");
JsonElement versionNode = cublasNode.GetProperty("version");
if (versionNode.ValueKind == JsonValueKind.Undefined)
{
return string.Empty;
}
return versionNode.GetString();
}
}
catch (Exception)
{
return string.Empty;
}
}

#if NET6_0_OR_GREATER
private static string GetAvxLibraryPath(NativeLibraryConfig.AvxLevel avxLevel, string prefix, string suffix)
{
var avxStr = NativeLibraryConfig.AvxLevelToString(avxLevel);
if (!string.IsNullOrEmpty(avxStr))
{
avxStr += "/";
}
return $"{prefix}{avxStr}{libraryName}{suffix}";
}

private static List<string> GetLibraryTryOrder(NativeLibraryConfig.Description configuration)
{
OSPlatform platform;
string prefix, suffix;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
platform = OSPlatform.Windows;
prefix = "runtimes/win-x64/native/";
suffix = ".dll";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
platform = OSPlatform.Linux;
prefix = "runtimes/linux-x64/native/";
suffix = ".so";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
platform = OSPlatform.OSX;
suffix = ".dylib";
if (System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported)
{
prefix = "runtimes/osx-arm64/native/";
}
else
{
prefix = "runtimes/osx-x64/native/";
}
}
else
{
throw new RuntimeError($"Your system plarform is not supported, please open an issue in LLamaSharp.");
}

List<string> result = new();
if (configuration.UseCuda && (platform == OSPlatform.Windows || platform == OSPlatform.Linux)) // no cuda on macos
{
int cudaVersion = GetCudaMajorVersion();

// TODO: load cuda library with avx
if (cudaVersion == -1 && !configuration.AllowFallback)
{
// if check skipped, we just try to load cuda libraries one by one.
if (configuration.SkipCheck)
{
result.Add($"{prefix}cuda12/{libraryName}{suffix}");
result.Add($"{prefix}cuda11/{libraryName}{suffix}");
}
else
{
throw new RuntimeError("Configured to load a cuda library but no cuda detected on your device.");
}
}
else if (cudaVersion == 11)
{
result.Add($"{prefix}cuda11/{libraryName}{suffix}");
}
else if (cudaVersion == 12)
{
result.Add($"{prefix}cuda12/{libraryName}{suffix}");
}
else if (cudaVersion > 0)
{
throw new RuntimeError($"Cuda version {cudaVersion} hasn't been supported by LLamaSharp, please open an issue for it.");
}
// otherwise no cuda detected but allow fallback
}

// use cpu (or mac possibly with metal)
if (!configuration.AllowFallback && platform != OSPlatform.OSX)
{
result.Add(GetAvxLibraryPath(configuration.AvxLevel, prefix, suffix));
}
else if(platform != OSPlatform.OSX) // in macos there's absolutely no avx
{
#if NET8_0_OR_GREATER
if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx512)
{
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx512, prefix, suffix)));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix)));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix)));
}
else
#endif
if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx2)
{
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix));
}
else if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx)
{
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix));
}
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.None, prefix, suffix));
}
if(platform == OSPlatform.OSX)
{
result.Add($"{prefix}{libraryName}{suffix}");
}

return result;
}
#endif

/// <summary>
/// Try to load libllama, using CPU feature detection to try and load a more specialised DLL if possible
/// </summary>
/// <returns>The library handle to unload later, or IntPtr.Zero if no library was loaded</returns>
private static IntPtr TryLoadLibrary()
{
#if NET6_0_OR_GREATER
var configuration = NativeLibraryConfig.GetInstance().Desc;

if (!string.IsNullOrEmpty(configuration.Path))
{
// When loading the user specified library, there's no fallback.
var result = TryLoad(configuration.Path, true);
if (result is null || result == IntPtr.Zero)
{
throw new RuntimeError($"Failed to load the native library [{configuration.Path}] you specified.");
}
return result ?? IntPtr.Zero;
}

var libraryTryLoadOrder = GetLibraryTryOrder(configuration);

string[] possiblePathPrefix = new string[] {
System.AppDomain.CurrentDomain.BaseDirectory,
Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
};

var tryFindPath = (string filename) =>
{
int i = 0;
while (!File.Exists(filename))
{
if (i < possiblePathPrefix.Length)
{
filename = Path.Combine(possiblePathPrefix[i], filename);
i++;
}
else
{
break;
}
}
return filename;
};

foreach (var libraryPath in libraryTryLoadOrder)
{
var fullPath = tryFindPath(libraryPath);
var result = TryLoad(fullPath, true);
if(result is not null && result != IntPtr.Zero)
{
Console.ForegroundColor = ConsoleColor.Red;
Console.WriteLine($"[Native Library] {fullPath} is loaded.");
Console.ResetColor();
return result ?? IntPtr.Zero;
}
else
{
Console.WriteLine($"Tried to load {fullPath}");
}
}

if (!configuration.AllowFallback)
{
throw new RuntimeError("Failed to load the library that match your rule, please" +
" 1) check your rule." +
" 2) try to allow fallback." +
" 3) or open an issue if it's expected to be successful.");
}
#endif

return IntPtr.Zero;

#if NET6_0_OR_GREATER
// Try to load a DLL from the path if supported. Returns null if nothing is loaded.
static IntPtr? TryLoad(string path, bool supported = true)
{
if (!supported)
return null;

if (NativeLibrary.TryLoad(path, out var handle))
return handle;

return null;
}
#endif
}

private const string libraryName = "libllama";
private const string cudaVersionFile = "version.json";

/// <summary>
/// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded.
/// </summary>


+201 -0  LLama/Native/NativeLibraryConfig.cs

@@ -0,0 +1,201 @@
using System;

namespace LLama.Native
{
#if NET6_0_OR_GREATER
/// <summary>
/// A class about configurations when loading native libraries.
/// Note that it could be configured only once before any call to llama model apis.
/// </summary>
public class NativeLibraryConfig
{
private static NativeLibraryConfig? instance;
private static readonly object lockObject = new object();
/// <summary>
/// The default configuration instance, through which all configuration should be done.
/// </summary>
public static NativeLibraryConfig Default
{
get
{
return GetInstance();
}
}

/// <summary>
/// Whether the native library has already been loaded. The configuration can no longer be modified once this is true.
/// </summary>
public static bool LibraryHasLoaded { get; internal set; } = false;

private string _libraryPath;
private bool _useCuda;
private AvxLevel _avxLevel;
private bool _allowFallback;
private bool _skipCheck;
private bool _logging;

internal static NativeLibraryConfig GetInstance()
{
if (instance is null)
{
lock (lockObject)
{
if (instance is null)
{
instance = new NativeLibraryConfig();
}
}
}
return instance;
}

/// <summary>
/// Load a specified native library as backend for LLamaSharp.
/// When this method is called, all the other configurations will be ignored.
/// </summary>
/// <param name="libraryPath"></param>
/// <exception cref="InvalidOperationException"></exception>
public NativeLibraryConfig WithLibrary(string libraryPath)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
_libraryPath = libraryPath;
return this;
}

/// <summary>
/// Configure whether to use cuda backend if possible.
/// </summary>
/// <param name="enable"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public NativeLibraryConfig WithCuda(bool enable = true)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
_useCuda = enable;
return this;
}

/// <summary>
/// Configure the preferred AVX support level of the backend.
/// </summary>
/// <param name="level"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public NativeLibraryConfig WithAvx(AvxLevel level)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
_avxLevel = level;
return this;
}

/// <summary>
/// Configure whether to allow fallback when there's no match for the preferred settings.
/// </summary>
/// <param name="enable"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public NativeLibraryConfig WithAutoFallback(bool enable = true)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
_allowFallback = enable;
return this;
}

/// <summary>
/// Whether to skip the check when you don't allow fallback. This option
/// may be useful under some complex conditions. For example, you're sure
/// you have your cublas configured but LLamaSharp takes it as invalid by mistake.
/// </summary>
/// <param name="enable"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public NativeLibraryConfig SkipCheck(bool enable = true)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
_skipCheck = enable;
return this;
}

/// <summary>
/// Whether to output the logs to console when loading the native library with your configuration.
/// </summary>
/// <param name="enable"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public NativeLibraryConfig WithLogs(bool enable = true)
{
if (LibraryHasLoaded)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
_logging = enable;
return this;
}

internal static Description CheckAndGatherDescription()
{
if (Default._allowFallback && Default._skipCheck)
{
throw new ArgumentException("Cannot skip the check when fallback is allowed.");
}
return new Description(Default._libraryPath, Default._useCuda, Default._avxLevel, Default._allowFallback, Default._skipCheck, Default._logging);
}

internal static string AvxLevelToString(AvxLevel level)
{
return level switch
{
AvxLevel.None => string.Empty,
AvxLevel.Avx => "avx",
AvxLevel.Avx2 => "avx2",
#if NET8_0_OR_GREATER
AvxLevel.Avx512 => "avx512"
#endif
_ => throw new ArgumentException($"Cannot recognize Avx level {level}")
};
}


private NativeLibraryConfig()
{
_libraryPath = string.Empty;
_useCuda = true;
_avxLevel = AvxLevel.Avx2;
_allowFallback = true;
_skipCheck = false;
_logging = false;
}

/// <summary>
/// Avx support configuration
/// </summary>
public enum AvxLevel
{
/// <summary>No AVX</summary>
None = 0,
/// <summary>Advanced Vector Extensions</summary>
Avx = 1,
/// <summary>AVX2</summary>
Avx2 = 2,
#if NET8_0_OR_GREATER
/// <summary>AVX-512</summary>
Avx512 = 3,
#endif
}
internal record Description(string Path = "", bool UseCuda = true, AvxLevel AvxLevel = AvxLevel.Avx2,
bool AllowFallback = true, bool SkipCheck = false, bool Logging = false);
}
#endif
}
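A minimal sketch of the fluent API above in use, mirroring the Program.cs change at the top of this commit (the values are illustrative, not recommendations):

    // Must run before the first call into NativeApi; once the native
    // library has been loaded the configuration is locked and any
    // further With* call throws InvalidOperationException.
    NativeLibraryConfig.Default
        .WithCuda()                                 // prefer a CUDA build when one is detected
        .WithAvx(NativeLibraryConfig.AvxLevel.Avx2) // preferred AVX level of the CPU build
        .WithLogs();                                // print loading diagnostics to the console

    NativeApi.llama_empty_call(); // forces the native library to load immediately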

+0 -113  LLama/NativeLibraryConfig.cs

@@ -1,113 +0,0 @@
using System;

namespace LLama
{
#if NET6_0_OR_GREATER
/// <summary>
/// A class about configurations when loading native libraries.
/// Note that it could be configured only once before any call to llama model apis.
/// </summary>
public class NativeLibraryConfig
{
private static NativeLibraryConfig? instance;
private static readonly object lockObject = new object();

/// <summary>
/// Whether there's already a config for native library.
/// </summary>
public bool Initialied { get; private set; }
internal Description Desc { get; private set; }

internal static NativeLibraryConfig GetInstance()
{
if (instance is null)
{
lock (lockObject)
{
if (instance is null)
{
instance = new NativeLibraryConfig();
}
}
}
return instance;
}

/// <summary>
/// Load a specified native library as backend for LLamaSharp
/// </summary>
/// <param name="libraryPath"></param>
/// <exception cref="InvalidOperationException"></exception>
public static void WithLibrary(string libraryPath)
{
var config = GetInstance();
if (config.Initialied)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
config.Desc = new Description(libraryPath);
}

/// <summary>
/// Add rules to match a suitable library from the installed LLamaSharp backend.
/// </summary>
/// <param name="useCuda"></param>
/// <param name="avxLevel"></param>
/// <param name="allowFallback">Whether to allow fall-back when your hardware doesn't support your configuration.</param>
/// <param name="skipCheck">Whether to skip the check when fallback is allowed.
/// It's especially useful when your cuda library is not in the default path. </param>
/// <exception cref="InvalidOperationException"></exception>
public static void WithMatchRule(bool useCuda = true, AvxLevel avxLevel = AvxLevel.Avx2, bool allowFallback = true, bool skipCheck = false)
{
if(allowFallback && skipCheck)
{
throw new ArgumentException("Cannot skip the check when fallback is allowed.");
}
var config = GetInstance();
if (config.Initialied)
{
throw new InvalidOperationException("NativeLibraryConfig could be configured only once before any call to llama model apis.");
}
config.Desc = new Description(UseCuda: useCuda, AvxLevel: avxLevel, AllowFallback: allowFallback, SkipCheck: skipCheck);
}

internal static string AvxLevelToString(AvxLevel level)
{
return level switch
{
AvxLevel.None => string.Empty,
AvxLevel.Avx => "avx",
AvxLevel.Avx2 => "avx2",
#if NET8_0_OR_GREATER
AvxLevel.Avx512 => "avx512"
#endif
_ => throw new ArgumentException($"Cannot recognize Avx level {level}")
};
}


private NativeLibraryConfig()
{
Desc = new Description();
}

/// <summary>
/// Avx support configuration
/// </summary>
public enum AvxLevel
{
/// <inheritdoc />
None = 0,
/// <inheritdoc />
Avx = 1,
/// <inheritdoc />
Avx2 = 2,
#if NET8_0_OR_GREATER
/// <inheritdoc />
Avx512 = 3,
#endif
}
internal record Description(string Path = "", bool UseCuda = true, AvxLevel AvxLevel = AvxLevel.Avx2, bool AllowFallback = true, bool SkipCheck = false);
}
#endif
}

+145 -144  LLama/StreamingTokenDecoder.cs

@@ -6,169 +6,170 @@ using System.Text;
using LLama.Extensions;
using LLama.Native;

(As in AntipromptProcessor.cs, old and new versions differ only in indentation: the class is now wrapped in a block-scoped namespace. The resulting file:)

namespace LLama
{
    /// <summary>
    /// Decodes a stream of tokens into a stream of characters
    /// </summary>
    public sealed class StreamingTokenDecoder
    {
        private readonly SafeLlamaModelHandle _weights;
        private readonly Decoder _decoder;

        private readonly List<char> _characters = new();

        /// <summary>
        /// The number of decoded characters waiting to be read
        /// </summary>
        public int AvailableCharacters => _characters.Count;

        #region constructors
        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="encoding">Text encoding to use</param>
        /// <param name="weights">Model weights</param>
        public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
            : this(encoding, weights.NativeHandle)
        {
        }

        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="context">Context to retrieve encoding and model weights from</param>
        public StreamingTokenDecoder(LLamaContext context)
            : this(context.Encoding, context.NativeHandle)
        {
        }

        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="encoding">Text encoding to use</param>
        /// <param name="context">Context to retrieve model weights from</param>
        public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context)
            : this(encoding, context.ModelHandle)
        {
        }

        /// <summary>
        /// Create a new decoder
        /// </summary>
        /// <param name="encoding">Text encoding to use</param>
        /// <param name="weights">Models weights to use</param>
        public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights)
        {
            _weights = weights;
            _decoder = encoding.GetDecoder();
        }
        #endregion

        /// <summary>
        /// Add a single token to the decoder
        /// </summary>
        /// <param name="token"></param>
        public void Add(int token)
        {
            var charsArr = ArrayPool<char>.Shared.Rent(16);
            var bytesArr = ArrayPool<byte>.Shared.Rent(16);
            try
            {
                // Convert this token into bytes
                var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;

                // Convert those bytes into characters
                var bytesOffset = 0;
                var completed = false;
                while (!completed)
                {
                    // Decode some of the bytes into the temp char buffer. Keep doing this
                    // until all bytes have been consumed
                    _decoder.Convert(
                        bytesArr, bytesOffset, bytesAvailable,
                        charsArr, 0, charsArr.Length,
                        false,
                        out var bytesUsed, out var charsUsed, out completed
                    );
                    bytesOffset += bytesUsed;
                    bytesAvailable -= bytesUsed;

                    // Add the decoded characters to the output buffer
                    _characters.AddSpan(charsArr.AsSpan(0, charsUsed));
                }
            }
            finally
            {
                ArrayPool<char>.Shared.Return(charsArr);
                ArrayPool<byte>.Shared.Return(bytesArr);
            }

            return;

            // Converts a single token into bytes, using the `bytes` array as temporary storage.
            // If the `bytes` array is too small it will get a larger one from the ArrayPool.
            static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
            {
                // Try to get bytes
                var l = model.TokenToSpan(token, bytes);

                // Negative length indicates that the output was too small. Expand it to twice that size and try again.
                if (l < 0)
                {
                    // Return the old array to the pool and get a new one
                    ArrayPool<byte>.Shared.Return(bytes);
                    bytes = ArrayPool<byte>.Shared.Rent(-l * 2);

                    // Get bytes, this time it can't fail
                    l = model.TokenToSpan(token, bytes);
                }

                Debug.Assert(l >= 0);
                return new Span<byte>(bytes, 0, l);
            }
        }

        /// <summary>
        /// Add all tokens in the given enumerable
        /// </summary>
        /// <param name="tokens"></param>
        public void AddRange(IEnumerable<int> tokens)
        {
            foreach (var item in tokens)
                Add(item);
        }

        /// <summary>
        /// Read all decoded characters and clear the buffer
        /// </summary>
        /// <param name="dest"></param>
        public void Read(List<char> dest)
        {
            dest.AddRange(_characters);
            _characters.Clear();
        }

        /// <summary>
        /// Read all decoded characters as a string and clear the buffer
        /// </summary>
        /// <returns></returns>
        public string Read()
        {
            if (_characters.Count == 0)
                return "";

            var str = string.Join("", _characters);
            _characters.Clear();
            return str;
        }

        /// <summary>
        /// Set the decoder back to its initial state
        /// </summary>
        public void Reset()
        {
            _decoder.Reset();
            _characters.Clear();
        }
    }
}
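A minimal usage sketch of the decoder (model setup elided; context is an existing LLamaContext and tokens an IEnumerable<int> of sampled token ids):

    var decoder = new StreamingTokenDecoder(context);
    foreach (var token in tokens)
    {
        decoder.Add(token);            // buffer the token's bytes, decoding whatever is complete
        Console.Write(decoder.Read()); // drain the characters decoded so far
    }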
