You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Utils.cs 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. using LLama.Common;
  2. using LLama.Exceptions;
  3. using LLama.Native;
  4. using System;
  5. using System.Collections.Generic;
  6. using System.IO;
  7. using System.Linq;
  8. using System.Runtime.InteropServices;
  9. using System.Text;
  namespace LLama
  {
      using llama_token = Int32;

      /// <summary>
      /// Internal helpers wrapping the raw <see cref="NativeApi"/> calls used to load a model,
      /// tokenize text, evaluate tokens and read back logits and token strings.
      /// </summary>
      internal static class Utils
      {
          /// <summary>
          /// Loads the model described by <paramref name="params"/>, creates a context for it and,
          /// when <c>LoraAdapter</c> is set, applies that LoRA adapter to the model.
          /// </summary>
          /// <param name="params">Model and context settings.</param>
          /// <returns>The newly created context handle; the caller is responsible for disposing it.</returns>
          /// <exception cref="ArgumentException">The number of tensor splits is not exactly one.</exception>
          /// <exception cref="FileNotFoundException">The model file does not exist.</exception>
          public static SafeLLamaContextHandle InitLLamaContextFromModelParams(ModelParams @params)
          {
              var lparams = NativeApi.llama_context_default_params();
              lparams.n_ctx = @params.ContextSize;
              lparams.n_batch = @params.BatchSize;
              lparams.main_gpu = @params.MainGpu;
              lparams.n_gpu_layers = @params.GpuLayerCount;
              lparams.seed = @params.Seed;
              lparams.f16_kv = @params.UseFp16Memory;
              // Memory-mapping and memory-locking are separate native options, so each one is
              // driven by its own setting (use_mmap must not mirror the mlock flag).
              lparams.use_mmap = @params.UseMemorymap;
              lparams.use_mlock = @params.UseMemoryLock;
              lparams.logits_all = @params.Perplexity;
              lparams.embedding = @params.EmbeddingMode;
              lparams.low_vram = @params.LowVram;

              if (@params.TensorSplits.Length != 1)
              {
                  throw new ArgumentException("Currently multi-gpu support is not supported by " +
                      "both llama.cpp and LLamaSharp.");
              }
              lparams.tensor_split = @params.TensorSplits;

              if (!File.Exists(@params.ModelPath))
              {
                  throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
              }

              var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);

              SafeLLamaContextHandle ctx;
              try
              {
                  ctx = SafeLLamaContextHandle.Create(model, lparams);
              }
              catch
              {
                  // No context was created, so nothing else holds this model: release it rather
                  // than leaving the native allocation to the finalizer.
                  model.Dispose();
                  throw;
              }

              if (!string.IsNullOrEmpty(@params.LoraAdapter))
              {
                  try
                  {
                      model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
                  }
                  catch
                  {
                      // The caller never receives these handles on failure, so free them here.
                      // SafeHandle.Dispose is idempotent, so disposing both is safe either way.
                      ctx.Dispose();
                      model.Dispose();
                      throw;
                  }
              }

              return ctx;
          }

          /// <summary>
          /// Tokenizes <paramref name="text"/> with the model bound to <paramref name="ctx"/>.
          /// </summary>
          /// <param name="ctx">Context whose vocabulary is used.</param>
          /// <param name="text">Text to tokenize.</param>
          /// <param name="add_bos">Whether a beginning-of-sequence token is prepended.</param>
          /// <param name="encoding">Encoding used to convert <paramref name="text"/> to bytes for the native call.</param>
          /// <returns>The produced tokens.</returns>
          /// <exception cref="RuntimeError">The native tokenizer reported a failure.</exception>
          public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
          {
              // Initial capacity guess: one slot per encoded byte, plus one for BOS.
              var tokens = new llama_token[encoding.GetByteCount(text) + (add_bos ? 1 : 0)];
              var n = NativeApi.llama_tokenize(ctx, text, encoding, tokens, tokens.Length, add_bos);

              if (n < 0)
              {
                  // llama.cpp signals a too-small buffer by returning the negated number of tokens
                  // it needs; retry once with exactly that capacity.
                  var required = -n;
                  if (required > tokens.Length)
                  {
                      tokens = new llama_token[required];
                      n = NativeApi.llama_tokenize(ctx, text, encoding, tokens, tokens.Length, add_bos);
                  }
              }

              if (n < 0)
              {
                  throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
                      "specify the encoding.");
              }

              return tokens.Take(n);
          }

          /// <summary>
          /// Wraps the context's logits buffer (from <c>llama_get_logits</c>) in a span without copying.
          /// </summary>
          /// <param name="ctx">Context to read logits from.</param>
          /// <param name="length">Number of floats exposed; assumed not to exceed the native buffer size (not checked here).</param>
          /// <returns>A span over native memory owned by <paramref name="ctx"/>.</returns>
          public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
          {
              var logits = NativeApi.llama_get_logits(ctx);
              return new Span<float>(logits, length);
          }

          /// <summary>
          /// Evaluates <paramref name="n_tokens"/> tokens of <paramref name="tokens"/>, starting at
          /// <paramref name="startIndex"/>, via <c>llama_eval_with_pointer</c>.
          /// </summary>
          /// <returns>The native return code.</returns>
          /// <exception cref="ArgumentOutOfRangeException">
          /// The range <paramref name="startIndex"/>..<paramref name="startIndex"/>+<paramref name="n_tokens"/>
          /// does not lie within <paramref name="tokens"/>.
          /// </exception>
          public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
          {
              // AsSpan bounds-checks the requested range, so native code is never handed a
              // pointer range that runs past the end of the array.
              var window = tokens.AsSpan(startIndex, n_tokens);
              fixed (llama_token* p = window)
              {
                  return NativeApi.llama_eval_with_pointer(ctx, p, n_tokens, n_past, n_threads);
              }
          }

          /// <summary>
          /// Returns the text of <paramref name="token"/> as reported by <c>llama_token_to_str</c>,
          /// decoded with <paramref name="encoding"/>.
          /// </summary>
          public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding)
          {
              return PtrToString(NativeApi.llama_token_to_str(ctx, token), encoding);
          }

          /// <summary>
          /// Decodes a NUL-terminated native string using <paramref name="encoding"/>.
          /// </summary>
          /// <param name="ptr">Pointer to the string; <see cref="IntPtr.Zero"/> yields <c>null</c>.</param>
          /// <param name="encoding">Encoding of the bytes at <paramref name="ptr"/>.</param>
          /// <returns>The decoded string, or <c>null</c> for a null pointer.</returns>
          public static unsafe string PtrToString(IntPtr ptr, Encoding encoding)
          {
              // Matches Marshal.PtrToString* semantics on every target instead of dereferencing null.
              if (ptr == IntPtr.Zero)
              {
                  return null;
              }

              // UTF-16 text ends with a two-byte terminator, which the single-byte scan below would
              // misdetect, so let the marshaller handle it.
              if (encoding == Encoding.Unicode)
              {
                  return Marshal.PtrToStringUni(ptr);
              }

  #if NET6_0_OR_GREATER
              if (encoding == Encoding.UTF8)
              {
                  return Marshal.PtrToStringUTF8(ptr);
              }
  #endif

              // Any other encoding: locate the NUL byte and decode with the caller's encoding.
              // (Marshal.PtrToStringAuto is not used because it decodes as UTF-16 on Windows
              // regardless of the requested encoding.)
              var bytes = (byte*)ptr.ToPointer();
              var length = 0;
              while (bytes[length] != 0)
              {
                  length++;
              }
              return encoding.GetString(bytes, length);
          }
      }
  }