You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

NativeApi.Load.cs 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. using LLama.Exceptions;
  2. using Microsoft.Extensions.Logging;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.IO;
  7. using System.Linq;
  8. using System.Runtime.InteropServices;
  9. using System.Text.Json;
  10. using System.Text.RegularExpressions;
  namespace LLama.Native
  {
      public partial class NativeApi
      {
          /// <summary>
          /// Loads the most suitable native llama library (see <see cref="TryLoadLibrary"/>), checks that a
          /// native call can be dispatched, then initializes the llama backend.
          /// </summary>
          /// <exception cref="RuntimeError">The native library could not be loaded.</exception>
          static NativeApi()
          {
              // Try to load a preferred library, based on CPU feature detection
              TryLoadLibrary();
              try
              {
                  // Probe call: a DllNotFoundException here means no usable native library was resolved.
                  llama_empty_call();
              }
              catch (DllNotFoundException)
              {
                  throw new RuntimeError("The native library cannot be correctly loaded. It could be one of the following reasons: \n" +
                      "1. No LLamaSharp backend was installed. Please search LLamaSharp.Backend and install one of them. \n" +
                      "2. You are using a device with only CPU but installed cuda backend. Please install cpu backend instead. \n" +
                      "3. One of the dependency of the native library is missed. Please use `ldd` on linux, `dumpbin` on windows and `otool` " +
                      "to check if all the dependency of the native library is satisfied. Generally you could find the libraries under your output folder.\n" +
                      "4. Try to compile llama.cpp yourself to generate a libllama library, then use `LLama.Native.NativeLibraryConfig.WithLibrary` " +
                      "to specify it at the very beginning of your code. For more informations about compilation, please refer to LLamaSharp repo on github.\n");
              }
              llama_backend_init(false);
          }

          /// <summary>
          /// Writes a colored, prefixed line to the console when native-library logging is enabled.
          /// </summary>
          /// <param name="message">Text to write.</param>
          /// <param name="level">Expected to be Information, Warning or Error. Any other level is printed as a warning.</param>
          private static void Log(string message, LogLevel level)
          {
              if (!enableLogging) return;
              Debug.Assert(level is LogLevel.Information or LogLevel.Error or LogLevel.Warning);

              var (color, levelPrefix) = level switch
              {
                  LogLevel.Information => (ConsoleColor.Green, "[Info]"),
                  LogLevel.Error => (ConsoleColor.Red, "[Error]"),
                  _ => (ConsoleColor.Yellow, "[Warning]"),
              };

              Console.ForegroundColor = color;
              try
              {
                  Console.WriteLine($"{loggingPrefix} {levelPrefix} {message}");
              }
              finally
              {
                  // Never leave the console recolored, even if the write fails.
                  Console.ResetColor();
              }
          }

          /// <summary>
          /// Detects the major version of the installed CUDA toolkit.
          /// </summary>
          /// <remarks>
          /// On Windows, the toolkit at <c>CUDA_PATH</c> is inspected. If that variable is unset or empty,
          /// the version reported by <c>nvidia-smi</c> is used instead. On Linux, the default install
          /// location is tried first. After that, the parent of each <c>LD_LIBRARY_PATH</c> entry is tried.
          /// Other platforms always report -1.
          /// </remarks>
          /// <returns>The CUDA major version (for example 11 or 12), or -1 if it could not be detected.</returns>
          private static int GetCudaMajorVersion()
          {
              string version = string.Empty;
              if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
              {
                  string? cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH");
                  version = cudaPath is { Length: > 0 }
                      ? GetCudaVersionFromPath(cudaPath)
                      : GetCudaVersionFromDriverUtils_windows();
              }
              else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
              {
                  // Try the default first
                  version = GetCudaVersionFromPath(defaultLinuxCudaPath);
                  if (string.IsNullOrEmpty(version))
                  {
                      string? ldLibraryPath = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH");
                      if (string.IsNullOrEmpty(ldLibraryPath))
                      {
                          return -1;
                      }

                      // The version file is looked up one level above each library directory.
                      // Stop at the first directory that yields a version.
                      foreach (var path in ldLibraryPath!.Split(new[] { ':' }, StringSplitOptions.RemoveEmptyEntries))
                      {
                          version = GetCudaVersionFromPath(Path.Combine(path, ".."));
                          if (!string.IsNullOrEmpty(version))
                          {
                              break;
                          }
                      }
                  }
              }

              if (string.IsNullOrEmpty(version))
              {
                  return -1;
              }

              // Only the leading component ("12" of "12.2.5.6") is the major version.
              string major = version.Split('.')[0];
              return int.TryParse(major, out var majorVersion) ? majorVersion : -1;
          }

          /// <summary>
          /// Queries <c>nvidia-smi</c> for the CUDA version it reports.
          /// </summary>
          /// <returns>
          /// A "major.minor" string, or <see cref="string.Empty"/> in these cases: the tool cannot be
          /// started, its output has no recognizable version, or any error occurs.
          /// </returns>
          private static string GetCudaVersionFromDriverUtils_windows()
          {
              try
              {
                  var psi = new ProcessStartInfo
                  {
                      FileName = "nvidia-smi",
                      RedirectStandardOutput = true,
                      UseShellExecute = false,
                      CreateNoWindow = true
                  };

                  using var process = Process.Start(psi);
                  if (process is null)
                  {
                      return string.Empty;
                  }

                  using StreamReader reader = process.StandardOutput;
                  // Read all output before waiting so the child cannot block on a full stdout pipe.
                  string output = reader.ReadToEnd();
                  process.WaitForExit();

                  // After the key, the text is expected to look like ": <major>.<minor> ...".
                  string cudaVersion = GetNvidiaSmiValue(output, "CUDA Version");
                  Match match = Regex.Match(cudaVersion, @":\s(\d+\.\d+)");
                  return match.Success && match.Groups.Count > 1 ? match.Groups[1].Value : string.Empty;
              }
              catch (Exception)
              {
                  // Best-effort probe: a missing or failing nvidia-smi means "no CUDA detected".
                  return string.Empty;
              }
          }

          /// <summary>
          /// Returns the text that follows <paramref name="key"/> on the same line of <paramref name="nvidiaSmiOutput"/>.
          /// </summary>
          /// <param name="nvidiaSmiOutput">Full text printed by nvidia-smi.</param>
          /// <param name="key">Literal label to search for (matched ordinally).</param>
          /// <returns>The trimmed remainder of the line after the key, or "N/A" if the key does not occur.</returns>
          private static string GetNvidiaSmiValue(string nvidiaSmiOutput, string key)
          {
              int startIndex = nvidiaSmiOutput.IndexOf(key, StringComparison.Ordinal);
              if (startIndex == -1)
              {
                  return "N/A";
              }

              startIndex += key.Length;
              int endIndex = nvidiaSmiOutput.IndexOf('\n', startIndex);
              if (endIndex == -1)
              {
                  endIndex = nvidiaSmiOutput.Length;
              }
              return nvidiaSmiOutput.Substring(startIndex, endIndex - startIndex).Trim();
          }

          /// <summary>
          /// Reads the <c>libcublas.version</c> string from the CUDA <c>version.json</c> under <paramref name="cudaPath"/>.
          /// </summary>
          /// <param name="cudaPath">Root directory of a CUDA toolkit installation.</param>
          /// <returns>
          /// The version string, or <see cref="string.Empty"/> in these cases: the file is missing,
          /// unreadable or malformed, or it lacks that string property.
          /// </returns>
          private static string GetCudaVersionFromPath(string cudaPath)
          {
              try
              {
                  string versionFilePath = Path.Combine(cudaPath, cudaVersionFile);
                  if (!File.Exists(versionFilePath))
                  {
                      return string.Empty;
                  }

                  using JsonDocument document = JsonDocument.Parse(File.ReadAllText(versionFilePath));
                  JsonElement root = document.RootElement;
                  // TryGetProperty throws on non-objects, so check ValueKind at each level.
                  if (root.ValueKind == JsonValueKind.Object
                      && root.TryGetProperty("libcublas", out JsonElement cublasNode)
                      && cublasNode.ValueKind == JsonValueKind.Object
                      && cublasNode.TryGetProperty("version", out JsonElement versionNode)
                      && versionNode.ValueKind == JsonValueKind.String)
                  {
                      return versionNode.GetString() ?? string.Empty;
                  }
                  return string.Empty;
              }
              catch (Exception)
              {
                  // Best-effort probe: I/O or JSON errors mean "no CUDA found at this path".
                  return string.Empty;
              }
          }

  #if NET6_0_OR_GREATER
          /// <summary>
          /// Builds the relative path of the CPU library for the given AVX level.
          /// </summary>
          /// <returns>"{prefix}{avx}/{libraryName}{suffix}", or "{prefix}{libraryName}{suffix}" when the level maps to an empty name.</returns>
          private static string GetAvxLibraryPath(NativeLibraryConfig.AvxLevel avxLevel, string prefix, string suffix)
          {
              var avxStr = NativeLibraryConfig.AvxLevelToString(avxLevel);
              if (!string.IsNullOrEmpty(avxStr))
              {
                  avxStr += "/";
              }
              return $"{prefix}{avxStr}{libraryName}{suffix}";
          }

          /// <summary>
          /// Computes the ordered list of relative library paths to attempt for the current OS and configuration.
          /// </summary>
          /// <param name="configuration">The gathered native-library configuration.</param>
          /// <returns>Relative library paths, most preferred first.</returns>
          /// <exception cref="RuntimeError">
          /// Thrown in these cases: the OS is unsupported; CUDA is required but not detected and the check is
          /// not skipped; or the detected CUDA version is unsupported and fallback is disallowed.
          /// </exception>
          private static List<string> GetLibraryTryOrder(NativeLibraryConfig.Description configuration)
          {
              OSPlatform platform;
              string prefix, suffix;
              if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
              {
                  platform = OSPlatform.Windows;
                  prefix = "runtimes/win-x64/native/";
                  suffix = ".dll";
              }
              else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
              {
                  platform = OSPlatform.Linux;
                  prefix = "runtimes/linux-x64/native/";
                  suffix = ".so";
              }
              else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
              {
                  platform = OSPlatform.OSX;
                  suffix = ".dylib";
                  prefix = System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported
                      ? "runtimes/osx-arm64/native/"
                      : "runtimes/osx-x64/native/";
              }
              else
              {
                  throw new RuntimeError("Your system platform is not supported, please open an issue in LLamaSharp.");
              }

              Log($"Detected OS Platform: {platform}", LogLevel.Information);
              List<string> result = new();

              if (configuration.UseCuda && (platform == OSPlatform.Windows || platform == OSPlatform.Linux)) // no cuda on macos
              {
                  int cudaVersion = GetCudaMajorVersion();
                  // TODO: load cuda library with avx
                  if (cudaVersion == 11 || cudaVersion == 12)
                  {
                      Log($"Detected cuda major version {cudaVersion}.", LogLevel.Information);
                      result.Add($"{prefix}cuda{cudaVersion}/{libraryName}{suffix}");
                  }
                  else if (cudaVersion > 0)
                  {
                      // A CUDA version without a matching backend only blocks loading when fallback is disallowed.
                      if (!configuration.AllowFallback)
                      {
                          throw new RuntimeError($"Cuda version {cudaVersion} hasn't been supported by LLamaSharp, please open an issue for it.");
                      }
                      Log($"Cuda version {cudaVersion} hasn't been supported by LLamaSharp, falling back.", LogLevel.Warning);
                  }
                  else if (!configuration.AllowFallback)
                  {
                      // if check skipped, we just try to load cuda libraries one by one.
                      if (configuration.SkipCheck)
                      {
                          result.Add($"{prefix}cuda12/{libraryName}{suffix}");
                          result.Add($"{prefix}cuda11/{libraryName}{suffix}");
                      }
                      else
                      {
                          throw new RuntimeError("Configured to load a cuda library but no cuda detected on your device.");
                      }
                  }
                  else
                  {
                      // no cuda detected but fallback is allowed
                      Log("No cuda detected, falling back.", LogLevel.Warning);
                  }
              }

              // use cpu (or mac possibly with metal)
              // NOTE(review): the CPU library is appended even when a CUDA library was requested and fallback is
              // disallowed, so a failed CUDA load can still end in a CPU load. Confirm this is intended.
              if (platform == OSPlatform.OSX)
              {
                  // in macos there's absolutely no avx
                  result.Add($"{prefix}{libraryName}{suffix}");
              }
              else if (!configuration.AllowFallback)
              {
                  result.Add(GetAvxLibraryPath(configuration.AvxLevel, prefix, suffix));
              }
              else
              {
                  // Highest AVX level permitted by the configuration first, stepping down to no AVX.
                  if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx512)
                      result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx512, prefix, suffix));
                  if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx2)
                      result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix));
                  if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx)
                      result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix));
                  result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.None, prefix, suffix));
              }
              return result;
          }
  #endif

          /// <summary>
          /// Try to load libllama, using CPU feature detection to try and load a more specialised DLL if possible
          /// </summary>
          /// <returns>The library handle to unload later, or IntPtr.Zero if no library was loaded</returns>
          /// <exception cref="RuntimeError">
          /// (net6+ only) Thrown in these cases: the user-specified library fails to load; or no candidate
          /// loads and fallback is disallowed.
          /// </exception>
          private static IntPtr TryLoadLibrary()
          {
  #if NET6_0_OR_GREATER
              var configuration = NativeLibraryConfig.CheckAndGatherDescription();
              enableLogging = configuration.Logging;
              // We move the flag to avoid loading library when the variable is called else where.
              NativeLibraryConfig.LibraryHasLoaded = true;
              Log(configuration.ToString(), LogLevel.Information);

              if (!string.IsNullOrEmpty(configuration.Path))
              {
                  // When loading the user specified library, there's no fallback.
                  if (!NativeLibrary.TryLoad(configuration.Path, out var userHandle))
                  {
                      throw new RuntimeError($"Failed to load the native library [{configuration.Path}] you specified.");
                  }
                  Log($"Successfully loaded the library [{configuration.Path}] specified by user", LogLevel.Information);
                  return userHandle;
              }

              var libraryTryLoadOrder = GetLibraryTryOrder(configuration);

              // User-configured search directories take priority over the app base and assembly directories.
              string[] preferredPaths = configuration.SearchDirectories;
              string[] searchDirectories = preferredPaths.Concat(new[]
              {
                  System.AppDomain.CurrentDomain.BaseDirectory,
                  Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
              }).ToArray();

              foreach (var libraryPath in libraryTryLoadOrder)
              {
                  var fullPath = FindLibraryFile(libraryPath, searchDirectories);
                  if (NativeLibrary.TryLoad(fullPath, out var handle) && handle != IntPtr.Zero)
                  {
                      Log($"{fullPath} is selected and loaded successfully.", LogLevel.Information);
                      return handle;
                  }
                  Log($"Tried to load {fullPath} but failed.", LogLevel.Information);
              }

              if (!configuration.AllowFallback)
              {
                  throw new RuntimeError("Failed to load the library that match your rule, please" +
                      " 1) check your rule." +
                      " 2) try to allow fallback." +
                      " 3) or open an issue if it's expected to be successful.");
              }
  #endif
              Log($"No library was loaded before calling native apis. " +
                  $"This is not an error under netstandard2.0 but needs attention with net6 or higher.", LogLevel.Warning);
              return IntPtr.Zero;

  #if NET6_0_OR_GREATER
              // Returns the first "<directory>/<filename>" that exists, trying directories in order.
              // If none exists, returns the bare relative filename so the OS loader's default search applies.
              static string FindLibraryFile(string filename, string[] directories)
              {
                  foreach (var directory in directories)
                  {
                      var candidate = Path.Combine(directory, filename);
                      if (File.Exists(candidate))
                      {
                          return candidate;
                      }
                  }
                  return filename;
              }
  #endif
          }

          private const string libraryName = "libllama";
          private const string cudaVersionFile = "version.json";
          // Conventional default CUDA toolkit install location on Linux.
          private const string defaultLinuxCudaPath = "/usr/local/cuda";
          private const string loggingPrefix = "[LLamaSharp Native]";
          private static bool enableLogging = false;
      }
  }