You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Utils.cs 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. using LLama.Native;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using LLama.Exceptions;
  6. using System.Linq;
  7. using System.Runtime.InteropServices;
  8. using System.IO;
  9. #pragma warning disable
  10. // ReSharper disable all
  11. namespace LLama.OldVersion
  12. {
  13. using llama_token = Int32;
  14. internal static class Utils
  15. {
  /// <summary>
  /// Copies the relevant settings from <paramref name="params"/> into a set of native
  /// context parameters, loads the model file, creates a context for it and, when a
  /// LoRA adapter is configured, applies that adapter to the loaded model.
  /// </summary>
  /// <param name="params">
  /// Source of the context settings, the model path and the optional LoRA adapter options.
  /// </param>
  /// <returns>The context handle created for the loaded model.</returns>
  /// <exception cref="FileNotFoundException">
  /// Thrown when the configured model path does not refer to an existing file.
  /// </exception>
  public static SafeLLamaContextHandle llama_init_from_gpt_params(ref LLamaParams @params)
  {
      // Start from the native defaults and override only the fields exposed by LLamaParams.
      var contextParams = NativeApi.llama_context_default_params();
      contextParams.n_ctx = @params.n_ctx;
      contextParams.n_gpu_layers = @params.n_gpu_layers;
      contextParams.seed = @params.seed;
      contextParams.f16_kv = @params.memory_f16;
      contextParams.use_mmap = @params.use_mmap;
      contextParams.use_mlock = @params.use_mlock;
      contextParams.logits_all = @params.perplexity;
      contextParams.embedding = @params.embedding;

      var modelPath = @params.model;
      if (File.Exists(modelPath))
      {
          var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, contextParams);
          var context = SafeLLamaContextHandle.Create(weights, contextParams);

          // The adapter is applied to the model handle after the context has been created.
          bool hasLoraAdapter = !string.IsNullOrEmpty(@params.lora_adapter);
          if (hasLoraAdapter)
          {
              weights.ApplyLoraFromFile(@params.lora_adapter, @params.lora_base, @params.n_threads);
          }

          return context;
      }

      throw new FileNotFoundException($"The model file does not exist: {modelPath}");
  }
  /// <summary>
  /// Tokenizes <paramref name="text"/> by calling <c>NativeApi.llama_tokenize</c> with the
  /// text encoding identified by <paramref name="encodingName"/>.
  /// </summary>
  /// <param name="ctx">Context handle passed through to the native tokenizer call.</param>
  /// <param name="text">Text to tokenize.</param>
  /// <param name="add_bos">Forwarded to the native call; also reserves one extra slot in the token buffer.</param>
  /// <param name="encodingName">Name resolved through <see cref="Encoding.GetEncoding(string)"/>.</param>
  /// <returns>The leading tokens of the buffer, as many as the native call reported.</returns>
  /// <exception cref="RuntimeError">Thrown when the native call returns a negative count.</exception>
  public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encodingName)
  {
      Encoding textEncoding = Encoding.GetEncoding(encodingName);

      // The buffer holds one slot per encoded byte, plus one more when add_bos is set.
      int capacity = textEncoding.GetByteCount(text);
      if (add_bos)
      {
          capacity++;
      }
      var buffer = new llama_token[capacity];

      int produced = NativeApi.llama_tokenize(ctx, text, textEncoding, buffer, buffer.Length, add_bos);
      if (produced < 0)
      {
          throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
                                 "specify the encoding.");
      }

      // Copy out only the reported prefix, never reading past the end of the buffer.
      int kept = Math.Min(produced, buffer.Length);
      var tokens = new List<llama_token>(kept);
      for (int i = 0; i < kept; i++)
      {
          tokens.Add(buffer[i]);
      }
      return tokens;
  }
  /// <summary>
  /// Wraps the pointer returned by <c>NativeApi.llama_get_logits</c> in a
  /// <see cref="Span{T}"/> of <paramref name="length"/> floats.
  /// </summary>
  /// <param name="ctx">Context handle passed to the native call.</param>
  /// <param name="length">Number of floats the returned span should cover.</param>
  /// <returns>A span over the native memory; no copy is made.</returns>
  /// <remarks>
  /// NOTE(review): the memory is presumably owned by the native context, so the span
  /// should not outlive <paramref name="ctx"/> — confirm against the native API.
  /// </remarks>
  public static unsafe Span<float> llama_get_logits(SafeLLamaContextHandle ctx, int length)
      => new Span<float>(NativeApi.llama_get_logits(ctx), length);
  /// <summary>
  /// Reads a null-terminated UTF-8 string from unmanaged memory into a managed string.
  /// </summary>
  /// <param name="ptr">Address of the first byte of the null-terminated UTF-8 string.</param>
  /// <returns>
  /// The decoded string, or <c>null</c> when <paramref name="ptr"/> is <see cref="IntPtr.Zero"/>.
  /// </returns>
  /// <remarks>
  /// On .NET 6 and later this forwards to <see cref="Marshal.PtrToStringUTF8(IntPtr)"/>.
  /// On older targets the terminator is located manually and the bytes are decoded with
  /// <see cref="Encoding.UTF8"/>; a zero pointer yields <c>null</c> there as well, so both
  /// build targets behave the same instead of the fallback dereferencing address 0.
  /// </remarks>
  public static unsafe string PtrToStringUTF8(IntPtr ptr)
  {
#if NET6_0_OR_GREATER
      return Marshal.PtrToStringUTF8(ptr);
#else
      // Mirror Marshal.PtrToStringUTF8: a null pointer maps to a null string.
      if (ptr == IntPtr.Zero)
      {
          return null;
      }

      // Find the terminating zero byte to learn the encoded length.
      byte* start = (byte*)ptr.ToPointer();
      int length = 0;
      while (start[length] != 0)
      {
          length++;
      }

      // Copy the payload once and decode it; an empty payload decodes to "".
      byte[] payload = new byte[length];
      Marshal.Copy(ptr, payload, 0, length);
      return Encoding.UTF8.GetString(payload);
#endif
  }
  80. }
  81. }