You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Utils.cs 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. using LLama.Abstractions;
  2. using LLama.Exceptions;
  3. using LLama.Native;
  4. using System;
  5. using System.Collections.Generic;
  6. using System.IO;
  7. using System.Linq;
  8. using System.Runtime.InteropServices;
  9. using System.Text;
  10. namespace LLama
  11. {
  12. using llama_token = Int32;
  /// <summary>
  /// Internal helpers that wrap calls into <see cref="NativeApi"/>: context creation,
  /// tokenization, evaluation, logits access and native-string decoding.
  /// </summary>
  internal static class Utils
  {
      /// <summary>
      /// Builds native context parameters from <paramref name="params"/>, loads the model file
      /// and creates a context for it. If a LoRA adapter is configured it is applied to the model
      /// after the context has been created.
      /// </summary>
      /// <param name="params">Model/context settings to copy into the native parameter struct.</param>
      /// <returns>The handle returned by <c>SafeLLamaContextHandle.Create</c>.</returns>
      /// <exception cref="FileNotFoundException">
      /// <c>@params.ModelPath</c> does not point to an existing file.
      /// </exception>
      public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
      {
          // Start from the native defaults and override only what the caller configured.
          var lparams = NativeApi.llama_context_default_params();
          lparams.n_ctx = @params.ContextSize;
          lparams.n_batch = @params.BatchSize;
          lparams.main_gpu = @params.MainGpu;
          lparams.n_gpu_layers = @params.GpuLayerCount;
          lparams.seed = @params.Seed;
          lparams.f16_kv = @params.UseFp16Memory;
          // Fix: use_mmap used to be populated from UseMemoryLock, which made memory-mapping
          // silently follow the mlock setting. It now reads the dedicated mmap setting.
          // NOTE(review): UseMemorymap is declared on IModelParams outside this file - confirm the member name.
          lparams.use_mmap = @params.UseMemorymap;
          lparams.use_mlock = @params.UseMemoryLock;
          lparams.logits_all = @params.Perplexity;
          lparams.embedding = @params.EmbeddingMode;
          lparams.low_vram = @params.LowVram;
          lparams.n_gqa = @params.GroupedQueryAttention;
          lparams.rms_norm_eps = @params.RmsNormEpsilon;
          lparams.rope_freq_base = @params.RopeFrequencyBase;
          lparams.rope_freq_scale = @params.RopeFrequencyScale;
          lparams.mul_mat_q = @params.MulMatQ;

          // A former guard that rejected TensorSplits whose length was not 1 (multi-GPU) is
          // intentionally disabled; the splits are forwarded unchanged.
          lparams.tensor_split = @params.TensorSplits;

          if (!File.Exists(@params.ModelPath))
          {
              throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
          }

          var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
          var ctx = SafeLLamaContextHandle.Create(model, lparams);

          if (!string.IsNullOrEmpty(@params.LoraAdapter))
          {
              model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
          }

          return ctx;
      }

      /// <summary>
      /// Tokenizes <paramref name="text"/> through <c>NativeApi.llama_tokenize</c>.
      /// </summary>
      /// <param name="ctx">Context handle passed to the native tokenizer.</param>
      /// <param name="text">Text to tokenize.</param>
      /// <param name="add_bos">Forwarded to the native call; also reserves one extra buffer slot.</param>
      /// <param name="encoding">Encoding used to size the buffer and passed to the native call.</param>
      /// <returns>The first <c>n</c> entries of the token buffer, where <c>n</c> is the native return value.</returns>
      /// <exception cref="RuntimeError">The native tokenizer returned a negative count.</exception>
      public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
      {
          // The encoded byte count (plus one slot when add_bos is set) is used as the buffer capacity.
          var capacity = encoding.GetByteCount(text) + (add_bos ? 1 : 0);
          var buffer = new llama_token[capacity];

          int produced = NativeApi.llama_tokenize(ctx, text, encoding, buffer, buffer.Length, add_bos);
          if (produced < 0)
          {
              throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
                  "specify the encoding.");
          }

          return buffer.Take(produced);
      }

      /// <summary>
      /// Wraps the pointer returned by <c>NativeApi.llama_get_logits</c> in a span of
      /// <paramref name="length"/> floats. No data is copied.
      /// </summary>
      /// <param name="ctx">Context handle whose logits are requested.</param>
      /// <param name="length">Number of floats exposed by the span; not validated here.</param>
      /// <returns>A span over the native logits memory.</returns>
      public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
      {
          var logits = NativeApi.llama_get_logits(ctx);
          return new Span<float>(logits, length);
      }

      /// <summary>
      /// Pins <paramref name="tokens"/> and calls <c>NativeApi.llama_eval_with_pointer</c> with a
      /// pointer offset by <paramref name="startIndex"/>.
      /// </summary>
      /// <param name="ctx">Context handle to evaluate on.</param>
      /// <param name="tokens">Token array; pinned for the duration of the native call.</param>
      /// <param name="startIndex">Offset into <paramref name="tokens"/>; no bounds check is performed here.</param>
      /// <param name="n_tokens">Forwarded to the native call.</param>
      /// <param name="n_past">Forwarded to the native call.</param>
      /// <param name="n_threads">Forwarded to the native call.</param>
      /// <returns>The value returned by the native call.</returns>
      public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
      {
          fixed (llama_token* tokenPtr = tokens)
          {
              return NativeApi.llama_eval_with_pointer(ctx, tokenPtr + startIndex, n_tokens, n_past, n_threads);
          }
      }

      /// <summary>
      /// Converts a token to text by decoding the pointer returned from
      /// <c>NativeApi.llama_token_to_str</c> with <see cref="PtrToString"/>.
      /// </summary>
      /// <param name="token">Token to convert.</param>
      /// <param name="ctx">Context handle passed to the native call.</param>
      /// <param name="encoding">Encoding used to decode the native string.</param>
      /// <returns>The decoded text.</returns>
      public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding)
      {
          return PtrToString(NativeApi.llama_token_to_str(ctx, token), encoding);
      }

      /// <summary>
      /// Decodes a native string at <paramref name="ptr"/> using <paramref name="encoding"/>.
      /// </summary>
      /// <remarks>
      /// On .NET 6+ the <see cref="Encoding.UTF8"/> and <see cref="Encoding.Unicode"/> singletons
      /// use the matching <see cref="Marshal"/> helper. Every other encoding, and every target
      /// framework without those fast paths, reads bytes up to the first zero byte and decodes
      /// them with <paramref name="encoding"/>. Previously the .NET 6+ fallback called
      /// <c>Marshal.PtrToStringAuto</c>, which ignores the requested encoding.
      /// </remarks>
      /// <param name="ptr">Pointer to the native string; the byte-scanning path dereferences it without a null check.</param>
      /// <param name="encoding">Encoding of the native bytes.</param>
      /// <returns>The decoded string.</returns>
      public static unsafe string PtrToString(IntPtr ptr, Encoding encoding)
      {
#if NET6_0_OR_GREATER
          if (encoding == Encoding.UTF8)
          {
              return Marshal.PtrToStringUTF8(ptr);
          }

          if (encoding == Encoding.Unicode)
          {
              return Marshal.PtrToStringUni(ptr);
          }
#endif
          // Measure the string first so the bytes can be decoded in place,
          // without growing an intermediate list one byte at a time.
          byte* start = (byte*)ptr.ToPointer();
          int length = 0;
          while (start[length] != 0)
          {
              length++;
          }

          return length == 0 ? string.Empty : encoding.GetString(start, length);
      }
  }
  115. }