You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

NativeApi.Load.cs 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. using LLama.Exceptions;
  2. using System;
  3. using System.IO;
  4. using System.Runtime.InteropServices;
  5. using System.Text.Json;
  6. using System.Collections.Generic;
  7. namespace LLama.Native
  8. {
  9. public static partial class NativeApi
  10. {
  /// <summary>
  /// Type initializer. Installs the DLL import resolver, marks the native library configuration
  /// as loaded, forces the native library to load by making a trivial native call, then
  /// registers the configured log callback (if any) and initialises the llama.cpp backend.
  /// </summary>
  /// <exception cref="RuntimeError">
  /// Thrown when the first native call raises <see cref="DllNotFoundException"/>.
  /// </exception>
  static NativeApi()
  {
      // Overwrite the Dll import resolver for this assembly. The resolver gets
      // called by the runtime every time that a call into a DLL is required. The
      // resolver returns the loaded DLL handle. This allows us to take control of
      // which llama.dll is used.
      SetDllImportResolver();

      // Set flag to indicate that this point has been passed. No native library config can be done after this point.
      NativeLibraryConfig.LibraryHasLoaded = true;

      // Immediately make a call which requires loading the llama DLL. This method call
      // can't fail unless the DLL hasn't been loaded.
      try
      {
          llama_empty_call();
      }
      catch (DllNotFoundException)
      {
          // Note the trailing space after `otool`: without it the two concatenated
          // fragments were rendered as "`otool`to check".
          throw new RuntimeError("The native library cannot be correctly loaded. It could be one of the following reasons: \n" +
              "1. No LLamaSharp backend was installed. Please search LLamaSharp.Backend and install one of them. \n" +
              "2. You are using a device with only CPU but installed cuda backend. Please install cpu backend instead. \n" +
              "3. One of the dependency of the native library is missed. Please use `ldd` on linux, `dumpbin` on windows and `otool` " +
              "to check if all the dependency of the native library is satisfied. Generally you could find the libraries under your output folder.\n" +
              "4. Try to compile llama.cpp yourself to generate a libllama library, then use `LLama.Native.NativeLibraryConfig.WithLibrary` " +
              "to specify it at the very beginning of your code. For more informations about compilation, please refer to LLamaSharp repo on github.\n");
      }

      // Now that the "loaded" flag is set configure logging in llama.cpp
      if (NativeLibraryConfig.Instance.LogCallback != null)
          NativeLogConfig.llama_log_set(NativeLibraryConfig.Instance.LogCallback);

      // Init llama.cpp backend
      llama_backend_init();
  }
  #if NET5_0_OR_GREATER
  // Handles returned by TryLoadLibraries, remembered so that later resolver
  // calls for the same library name return the handle loaded the first time.
  private static IntPtr _loadedLlamaHandle;
  private static IntPtr _loadedLlavaSharedHandle;
  #endif

  /// <summary>
  /// Registers a DLL import resolver for this assembly which answers requests for the
  /// "llama" and "llava_shared" libraries. On runtimes older than .NET 5 this does nothing.
  /// </summary>
  private static void SetDllImportResolver()
  {
      // NativeLibrary is not available on older runtimes, so normal runtime
      // dll resolution is all that is available there.
  #if NET5_0_OR_GREATER
      NativeLibrary.SetDllImportResolver(typeof(NativeApi).Assembly, (name, _, _) =>
      {
          switch (name)
          {
              case "llama":
                  // Only attempt a load when no handle has been stored yet.
                  if (_loadedLlamaHandle == IntPtr.Zero)
                      _loadedLlamaHandle = TryLoadLibraries(LibraryName.Llama);
                  return _loadedLlamaHandle;

              case "llava_shared":
                  // Only attempt a load when no handle has been stored yet.
                  if (_loadedLlavaSharedHandle == IntPtr.Zero)
                      _loadedLlavaSharedHandle = TryLoadLibraries(LibraryName.LlavaShared);
                  return _loadedLlavaSharedHandle;

              default:
                  // A null pointer indicates that nothing was loaded by this resolver.
                  return IntPtr.Zero;
          }
      });
  #endif
  }
  /// <summary>
  /// Forwards a message to the configured log callback (if one is set), making sure the
  /// message ends with a newline character.
  /// </summary>
  /// <param name="message">Text to log; a trailing "\n" is appended when missing</param>
  /// <param name="level">Severity passed through to the callback</param>
  private static void Log(string message, LLamaLogLevel level)
  {
      // Compare ordinally: a culture-sensitive EndsWith may not recognise the "\n" at
      // the end of a "\r\n" sequence on ICU-based runtimes, which would append a duplicate newline.
      if (!message.EndsWith("\n", StringComparison.Ordinal))
          message += "\n";

      NativeLibraryConfig.Instance.LogCallback?.Invoke(level, message);
  }
  82. #region CUDA version
  /// <summary>
  /// Detects the major version of the installed CUDA toolkit by reading its version file.
  /// On Windows the toolkit is located through the CUDA_PATH environment variable; on Linux
  /// the default install locations are probed first, then the parent directory of every
  /// entry in LD_LIBRARY_PATH.
  /// </summary>
  /// <returns>The CUDA major version (e.g. 11 or 12), or -1 if it could not be determined</returns>
  private static int GetCudaMajorVersion()
  {
      string version = "";

      if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
      {
          var cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH");
          if (cudaPath is null)
          {
              return -1;
          }

          // Ensuring cuda bin path is reachable. Especially for MAUI environment.
          string cudaBinPath = Path.Combine(cudaPath, "bin");
          if (Directory.Exists(cudaBinPath))
          {
              AddDllDirectory(cudaBinPath);
          }

          version = GetCudaVersionFromPath(cudaPath);
      }
      else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
      {
          // Try the default locations first. The historical path is kept and tried first
          // for compatibility; "/usr/local/cuda" is the conventional toolkit location.
          foreach (var defaultPath in new[] { "/usr/local/bin/cuda", "/usr/local/cuda" })
          {
              version = GetCudaVersionFromPath(defaultPath);
              if (!string.IsNullOrEmpty(version))
              {
                  break;
              }
          }

          if (string.IsNullOrEmpty(version))
          {
              var libraryPaths = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH");
              if (libraryPaths is null)
              {
                  return -1;
              }

              foreach (var path in libraryPaths.Split(':'))
              {
                  // The version file is looked up one level above each library directory.
                  version = GetCudaVersionFromPath(Path.Combine(path, ".."));

                  // Stop at the first entry that yields a version. (Previously the loop
                  // stopped at the first entry that did NOT, so only the first entry was
                  // ever useful and a later hit could be overwritten.)
                  if (!string.IsNullOrEmpty(version))
                  {
                      break;
                  }
              }
          }
      }

      if (string.IsNullOrEmpty(version))
          return -1;

      // Keep only the major component of a dotted version string.
      version = version.Split('.')[0];
      if (int.TryParse(version, out var majorVersion))
          return majorVersion;

      return -1;
  }
  /// <summary>
  /// Reads the CUDA version file inside the given directory and extracts the
  /// "version" value of its "libcublas" entry.
  /// </summary>
  /// <param name="cudaPath">Directory expected to contain the version file</param>
  /// <returns>The version string, or an empty string if anything goes wrong while reading or parsing</returns>
  private static string GetCudaVersionFromPath(string cudaPath)
  {
      try
      {
          var versionFilePath = Path.Combine(cudaPath, cudaVersionFile);
          using var document = JsonDocument.Parse(File.ReadAllText(versionFilePath));

          var versionNode = document.RootElement
              .GetProperty("libcublas")
              .GetProperty("version");

          return versionNode.ValueKind == JsonValueKind.Undefined
              ? string.Empty
              : versionNode.GetString() ?? "";
      }
      catch (Exception)
      {
          // Best effort: a missing file, malformed JSON or a missing property all mean "unknown".
          return string.Empty;
      }
  }
  153. #endregion
  154. #if NET6_0_OR_GREATER
  /// <summary>
  /// Lazily produces the relative library paths to try, in order of preference:
  /// CUDA builds first (when requested and on Windows/Linux), then the Metal/CPU builds.
  /// </summary>
  /// <param name="configuration">Description of what the user asked to load</param>
  /// <returns>Relative paths of candidate library files</returns>
  /// <exception cref="RuntimeError">
  /// Raised during enumeration when CUDA is required but not detected, or when the
  /// detected CUDA major version is not one of the handled versions.
  /// </exception>
  private static IEnumerable<string> GetLibraryTryOrder(NativeLibraryConfig.Description configuration)
  {
      var loadingName = configuration.Library.GetLibraryName();
      Log($"Loading library: '{loadingName}'", LLamaLogLevel.Debug);

      // Get platform specific parts of the path (e.g. .so/.dll/.dylib, libName prefix or not)
      GetPlatformPathParts(out var platform, out var os, out var ext, out var libPrefix);
      Log($"Detected OS Platform: '{platform}'", LLamaLogLevel.Info);
      Log($"Detected OS string: '{os}'", LLamaLogLevel.Debug);
      Log($"Detected extension string: '{ext}'", LLamaLogLevel.Debug);
      Log($"Detected prefix string: '{libPrefix}'", LLamaLogLevel.Debug);

      var cudaCapablePlatform = platform == OSPlatform.Windows || platform == OSPlatform.Linux;
      if (configuration.UseCuda && cudaCapablePlatform)
      {
          var cudaVersion = GetCudaMajorVersion();
          Log($"Detected cuda major version {cudaVersion}.", LLamaLogLevel.Info);

          switch (cudaVersion)
          {
              case -1 when !configuration.AllowFallback:
                  if (!configuration.SkipCheck)
                      throw new RuntimeError("Configured to load a cuda library but no cuda detected on your device.");

                  // The check is skipped, so just offer the cuda libraries one by one.
                  yield return GetCudaLibraryPath(loadingName, "cuda12");
                  yield return GetCudaLibraryPath(loadingName, "cuda11");
                  break;

              case 11:
                  yield return GetCudaLibraryPath(loadingName, "cuda11");
                  break;

              case 12:
                  yield return GetCudaLibraryPath(loadingName, "cuda12");
                  break;

              default:
                  if (cudaVersion > 0)
                      throw new RuntimeError($"Cuda version {cudaVersion} hasn't been supported by LLamaSharp, please open an issue for it.");

                  // Otherwise no cuda was detected but fallback is allowed.
                  break;
          }
      }

      // Add the CPU/Metal libraries
      if (platform == OSPlatform.OSX)
      {
          // On Mac there is a single library and no AVX level to consider.
          yield return GetMacLibraryPath(loadingName);
          yield break;
      }

      if (!configuration.AllowFallback)
      {
          // Fallback is not allowed - use the exact specified AVX level
          yield return GetAvxLibraryPath(loadingName, configuration.AvxLevel);
          yield break;
      }

      // Offer every AVX level up to the configured one, highest first, then the plain build.
      var descendingLevels = new[]
      {
          NativeLibraryConfig.AvxLevel.Avx512,
          NativeLibraryConfig.AvxLevel.Avx2,
          NativeLibraryConfig.AvxLevel.Avx,
      };
      foreach (var level in descendingLevels)
      {
          if (configuration.AvxLevel >= level)
              yield return GetAvxLibraryPath(loadingName, level);
      }

      yield return GetAvxLibraryPath(loadingName, NativeLibraryConfig.AvxLevel.None);
  }
  /// <summary>
  /// Create the relative path to the library file for the current platform, without any
  /// CUDA or AVX sub-folder
  /// </summary>
  /// <param name="libraryName">Library being loaded (e.g. "llama")</param>
  /// <returns>Path of the form runtimes/{os}/native/{prefix}{name}{extension}</returns>
  private static string GetMacLibraryPath(string libraryName)
  {
      GetPlatformPathParts(out _, out var rid, out var extension, out var prefix);

      var fileName = $"{prefix}{libraryName}{extension}";
      return $"runtimes/{rid}/native/{fileName}";
  }
  /// <summary>
  /// Build the relative path to the library file that lives in the folder of a given CUDA version
  /// </summary>
  /// <param name="libraryName">Library being loaded (e.g. "llama")</param>
  /// <param name="cuda">CUDA folder name (e.g. "cuda11")</param>
  /// <returns>Path of the form runtimes/{os}/native/{cuda}/{prefix}{name}{extension}</returns>
  private static string GetCudaLibraryPath(string libraryName, string cuda)
  {
      GetPlatformPathParts(out _, out var rid, out var extension, out var prefix);

      var fileName = $"{prefix}{libraryName}{extension}";
      return $"runtimes/{rid}/native/{cuda}/{fileName}";
  }
  /// <summary>
  /// Build the relative path to the library file that lives in the folder of a given AVX level
  /// </summary>
  /// <param name="libraryName">Library being loaded (e.g. "llama")</param>
  /// <param name="avx">AVX level whose folder should be used; no folder is added when its string form is empty</param>
  /// <returns>Path of the form runtimes/{os}/native/[{avx}/]{prefix}{name}{extension}</returns>
  private static string GetAvxLibraryPath(string libraryName, NativeLibraryConfig.AvxLevel avx)
  {
      GetPlatformPathParts(out _, out var rid, out var extension, out var prefix);

      // Only append the separator when there actually is a folder name.
      var avxName = NativeLibraryConfig.AvxLevelToString(avx);
      var avxFolder = string.IsNullOrEmpty(avxName) ? avxName : avxName + "/";

      return $"runtimes/{rid}/native/{avxFolder}{prefix}{libraryName}{extension}";
  }
  /// <summary>
  /// Determine the pieces needed to build a native library path for the current operating system
  /// </summary>
  /// <param name="platform">Detected operating system</param>
  /// <param name="os">Runtime folder name (e.g. "win-x64")</param>
  /// <param name="fileExtension">Library file extension, including the leading dot</param>
  /// <param name="libPrefix">Prefix placed before the library name ("lib" or empty)</param>
  /// <exception cref="RuntimeError">The operating system is not Windows, Linux or macOS</exception>
  private static void GetPlatformPathParts(out OSPlatform platform, out string os, out string fileExtension, out string libPrefix)
  {
      if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
      {
          (platform, os, fileExtension, libPrefix) = (OSPlatform.Windows, "win-x64", ".dll", "");
      }
      else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
      {
          (platform, os, fileExtension, libPrefix) = (OSPlatform.Linux, "linux-x64", ".so", "lib");
      }
      else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
      {
          // The runtime folder on macOS depends on whether Arm64 intrinsics are supported.
          var macRuntime = System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported
              ? "osx-arm64"
              : "osx-x64";

          (platform, os, fileExtension, libPrefix) = (OSPlatform.OSX, macRuntime, ".dylib", "lib");
      }
      else
      {
          throw new RuntimeError("Your operating system is not supported, please open an issue in LLamaSharp.");
      }
  }
  284. #endif
  /// <summary>
  /// Attempt to load libllama/llava_shared. A user-specified path is used exclusively when one is
  /// configured; otherwise the candidates from the preferred try-order are attempted one after another.
  /// </summary>
  /// <param name="lib">Which native library to load</param>
  /// <returns>The library handle to unload later, or IntPtr.Zero if no library was loaded</returns>
  /// <exception cref="RuntimeError">
  /// The user-specified library could not be loaded, or no candidate could be loaded while fallback is disallowed
  /// </exception>
  private static IntPtr TryLoadLibraries(LibraryName lib)
  {
  #if NET6_0_OR_GREATER
      var configuration = NativeLibraryConfig.CheckAndGatherDescription(lib);

      // From here on the NativeLibraryConfig must no longer be modified
      NativeLibraryConfig.LibraryHasLoaded = true;

      // Show the configuration we're working with
      Log(configuration.ToString(), LLamaLogLevel.Info);

      // An explicitly requested path either loads or fails immediately
      if (!string.IsNullOrEmpty(configuration.Path))
      {
          if (NativeLibrary.TryLoad(configuration.Path, out var userHandle))
          {
              Log($"Successfully loaded the library [{configuration.Path}] specified by user", LLamaLogLevel.Info);
              return userHandle;
          }

          throw new RuntimeError($"Failed to load the native library [{configuration.Path}] you specified.");
      }

      // Walk the candidate locations in order of preference
      foreach (var relativePath in GetLibraryTryOrder(configuration))
      {
          var resolvedPath = TryFindPath(relativePath);
          Log($"Trying '{resolvedPath}'", LLamaLogLevel.Debug);

          var loaded = TryLoad(resolvedPath);
          if (loaded == IntPtr.Zero)
          {
              Log($"Failed Loading '{resolvedPath}'", LLamaLogLevel.Info);
              continue;
          }

          Log($"Loaded '{resolvedPath}'", LLamaLogLevel.Info);
          return loaded;
      }

      if (!configuration.AllowFallback)
      {
          throw new RuntimeError("Failed to load the library that match your rule, please" +
              " 1) check your rule." +
              " 2) try to allow fallback." +
              " 3) or open an issue if it's expected to be successful.");
      }
  #endif

      Log($"No library was loaded before calling native apis. " +
          $"This is not an error under netstandard2.0 but needs attention with net6 or higher.", LLamaLogLevel.Warning);
      return IntPtr.Zero;

  #if NET6_0_OR_GREATER
      // Load a DLL from the path, giving back IntPtr.Zero when nothing is loaded.
      static IntPtr TryLoad(string path)
          => NativeLibrary.TryLoad(path, out var nativeHandle) ? nativeHandle : IntPtr.Zero;

      // Locate the given file in any of the possible search paths; when it is found
      // nowhere the file name is handed back unchanged.
      string TryFindPath(string filename)
      {
          // The search directories from the configuration take precedence
          foreach (var directory in configuration.SearchDirectories)
          {
              var inConfigured = Path.Combine(directory, filename);
              if (File.Exists(inConfigured))
                  return inConfigured;
          }

          // Then a few other likely locations
          var fallbackDirectories = new[]
          {
              AppDomain.CurrentDomain.BaseDirectory,
              Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
          };
          foreach (var directory in fallbackDirectories)
          {
              var inFallback = Path.Combine(directory, filename);
              if (File.Exists(inFallback))
                  return inFallback;
          }

          return filename;
      }
  #endif
  }
  /// <summary>
  /// Name of the llama native library; the same value the DLL import resolver compares against
  /// </summary>
  internal const string libraryName = "llama";

  /// <summary>
  /// Name of the llava native library; the same value the DLL import resolver compares against
  /// </summary>
  internal const string llavaLibraryName = "llava_shared";

  /// <summary>
  /// File name, relative to a CUDA directory, that is read to find the CUDA version
  /// </summary>
  private const string cudaVersionFile = "version.json";
  367. }
  368. }