You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Utils.cs 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. using LLama.Abstractions;
  2. using LLama.Native;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using System.Text;
  8. using LLama.Exceptions;
  9. using LLama.Extensions;
  namespace LLama
  {
      using llama_token = Int32;

      /// <summary>
      /// Static helpers for creating contexts, evaluating tokens and decoding native strings.
      /// </summary>
      public static class Utils
      {
          /// <summary>
          /// Loads the model at <see cref="IModelParams.ModelPath"/>, creates a context for it and,
          /// when <see cref="IModelParams.LoraAdapter"/> is set, applies that LoRA adapter to the model.
          /// </summary>
          /// <param name="params">Model/context parameters.</param>
          /// <returns>The newly created context handle.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="params"/> is <c>null</c>.</exception>
          public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
          {
              if (@params == null)
                  throw new ArgumentNullException(nameof(@params));

              using (@params.ToLlamaContextParams(out var lparams))
              {
                  var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);

                  SafeLLamaContextHandle ctx;
                  try
                  {
                      ctx = SafeLLamaContextHandle.Create(model, lparams);
                  }
                  catch
                  {
                      // No context was returned to the caller, so release the model now instead of
                      // leaving the native memory to finalization (assumes a SafeHandle-style Dispose).
                      model.Dispose();
                      throw;
                  }

                  if (!string.IsNullOrEmpty(@params.LoraAdapter))
                  {
                      try
                      {
                          model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
                      }
                      catch
                      {
                          // The caller never receives ctx on failure; free it promptly.
                          ctx.Dispose();
                          throw;
                      }
                  }

                  return ctx;
              }
          }

          /// <summary>
          /// Forwards to <c>ctx.Tokenize(text, add_bos, encoding)</c>.
          /// </summary>
          [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
          public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
          {
              return ctx.Tokenize(text, add_bos, encoding);
          }

          /// <summary>
          /// Forwards to <c>ctx.GetLogits()</c> after checking that <paramref name="length"/>
          /// equals <c>ctx.VocabCount</c>.
          /// </summary>
          /// <exception cref="ArgumentException"><paramref name="length"/> differs from <c>ctx.VocabCount</c>.</exception>
          [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")]
          public static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
          {
              if (length != ctx.VocabCount)
                  throw new ArgumentException("length must be the VocabSize", nameof(length));
              return ctx.GetLogits();
          }

          /// <summary>
          /// Passes <paramref name="n_tokens"/> tokens of <paramref name="tokens"/>, starting at
          /// <paramref name="startIndex"/>, to <c>NativeApi.llama_eval_with_pointer</c>.
          /// </summary>
          /// <param name="ctx">Context to evaluate in.</param>
          /// <param name="tokens">Token buffer; it is pinned for the duration of the native call.</param>
          /// <param name="startIndex">Index of the first token to evaluate.</param>
          /// <param name="n_tokens">Number of tokens to evaluate.</param>
          /// <param name="n_past">Forwarded unchanged to the native call.</param>
          /// <param name="n_threads">Forwarded unchanged to the native call.</param>
          /// <returns>The native call's return value, unmodified.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="tokens"/> is <c>null</c>.</exception>
          /// <exception cref="ArgumentOutOfRangeException">The requested range lies outside <paramref name="tokens"/>.</exception>
          public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
          {
              if (tokens == null)
                  throw new ArgumentNullException(nameof(tokens));
              // Native code reads through a raw pointer, so an invalid range would read past the array.
              if (startIndex < 0 || startIndex > tokens.Length)
                  throw new ArgumentOutOfRangeException(nameof(startIndex));
              // Compare against the remaining length to avoid int overflow in startIndex + n_tokens.
              if (n_tokens < 0 || n_tokens > tokens.Length - startIndex)
                  throw new ArgumentOutOfRangeException(nameof(n_tokens));

              fixed (llama_token* p = tokens)
              {
                  return NativeApi.llama_eval_with_pointer(ctx, p + startIndex, n_tokens, n_past, n_threads);
              }
          }

          /// <summary>
          /// Converts <paramref name="token"/> to text by decoding the pointer returned from
          /// <c>NativeApi.llama_token_to_str</c> with <paramref name="encoding"/>.
          /// </summary>
          public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding)
          {
              return PtrToString(NativeApi.llama_token_to_str(ctx, token), encoding);
          }

          /// <summary>
          /// Decodes a zero-terminated native string into a managed <see cref="string"/>.
          /// </summary>
          /// <param name="ptr">Pointer to the first byte of the native string; <see cref="IntPtr.Zero"/> yields an empty string.</param>
          /// <param name="encoding">Encoding used to decode the bytes; <c>null</c> is treated as UTF-8.</param>
          /// <returns>The decoded text, excluding the terminator.</returns>
          /// <remarks>
          /// The terminator is one zero code unit of <paramref name="encoding"/>. That is a single zero byte
          /// for byte-oriented encodings such as UTF-8, two zero bytes for UTF-16 and four for UTF-32.
          /// The bytes are always decoded with the caller's encoding rather than a platform default.
          /// </remarks>
          public static unsafe string PtrToString(IntPtr ptr, Encoding encoding)
          {
              if (ptr == IntPtr.Zero)
                  return string.Empty;

              encoding ??= Encoding.UTF8;

              // The encoded width of '\0' is the width of the terminator to search for.
              int unitSize = Math.Max(1, encoding.GetByteCount("\0"));

              byte* start = (byte*)ptr.ToPointer();
              int length = 0;
              while (!IsZeroUnit(start + length, unitSize))
                  length += unitSize;

              return encoding.GetString(start, length);
          }

          /// <summary>True when all <paramref name="size"/> bytes at <paramref name="unit"/> are zero.</summary>
          private static unsafe bool IsZeroUnit(byte* unit, int size)
          {
              for (int i = 0; i < size; i++)
              {
                  if (unit[i] != 0)
                      return false;
              }
              return true;
          }
      }
  }