You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

Utils.cs 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. using LLama.Abstractions;
  2. using LLama.Native;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using System.Text;
  8. using LLama.Exceptions;
  9. using LLama.Extensions;
  10. namespace LLama
  11. {
  12. using llama_token = Int32;
  13. public static class Utils
  14. {
  /// <summary>
  /// Loads the model described by <paramref name="params"/> and creates a context handle for it.
  /// </summary>
  /// <remarks>
  /// If <c>LoraAdapter</c> is non-empty, the adapter is applied to the loaded model after the
  /// context has been created, using <c>LoraBase</c> and <c>Threads</c> from the same parameters.
  /// The disposable returned by <c>ToLlamaContextParams</c> is released when this method exits.
  /// </remarks>
  /// <param name="params">Supplies the model path, the native context parameters and the LoRA settings.</param>
  /// <returns>The context handle created for the loaded model.</returns>
  public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
  {
      // Held for the whole method so the native parameters stay usable for every call below.
      using var nativeParamsScope = @params.ToLlamaContextParams(out var nativeParams);

      var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, nativeParams);
      var context = SafeLLamaContextHandle.Create(weights, nativeParams);

      if (string.IsNullOrEmpty(@params.LoraAdapter))
      {
          return context;
      }

      weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
      return context;
  }
  /// <summary>
  /// Tokenizes <paramref name="text"/> by forwarding to the <c>Tokenize</c> method of <paramref name="ctx"/>.
  /// </summary>
  /// <param name="ctx">Context handle that performs the tokenization.</param>
  /// <param name="text">Text to tokenize.</param>
  /// <param name="add_bos">Passed through unchanged; presumably controls whether a beginning-of-sequence token is prepended (confirm against the handle's method).</param>
  /// <param name="encoding">Encoding passed through to the handle's <c>Tokenize</c> method.</param>
  /// <returns>Whatever token sequence the handle's <c>Tokenize</c> method produces.</returns>
  [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
  public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
      => ctx.Tokenize(text, add_bos, encoding);
  /// <summary>
  /// Exposes the buffer returned by <c>NativeApi.llama_get_logits</c> as a span of
  /// <paramref name="length"/> floats.
  /// </summary>
  /// <remarks>
  /// The span is a view over the native pointer; nothing is copied and <paramref name="length"/>
  /// is not checked against the real size of the native buffer, so the caller must supply a
  /// correct value.
  /// </remarks>
  /// <param name="ctx">Context handle whose logits are requested.</param>
  /// <param name="length">Number of floats the returned span should cover.</param>
  /// <returns>A span over the native logits pointer.</returns>
  public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
      => new Span<float>(NativeApi.llama_get_logits(ctx), length);
  /// <summary>
  /// Passes a slice of <paramref name="tokens"/> to <c>NativeApi.llama_eval_with_pointer</c>.
  /// </summary>
  /// <param name="ctx">Context handle handed to the native call.</param>
  /// <param name="tokens">Array containing the tokens to evaluate.</param>
  /// <param name="startIndex">Index in <paramref name="tokens"/> of the first token to pass.</param>
  /// <param name="n_tokens">Number of tokens to pass, starting at <paramref name="startIndex"/>.</param>
  /// <param name="n_past">Forwarded unchanged to the native call.</param>
  /// <param name="n_threads">Forwarded unchanged to the native call.</param>
  /// <returns>The value returned by <c>NativeApi.llama_eval_with_pointer</c>.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="tokens"/> is null.</exception>
  /// <exception cref="ArgumentOutOfRangeException">
  /// <paramref name="startIndex"/> or <paramref name="n_tokens"/> is negative, or the requested
  /// slice does not fit inside <paramref name="tokens"/>.
  /// </exception>
  public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
  {
      if (tokens is null)
      {
          throw new ArgumentNullException(nameof(tokens));
      }

      if (startIndex < 0 || startIndex > tokens.Length)
      {
          throw new ArgumentOutOfRangeException(nameof(startIndex), startIndex, "Start index must lie within the token array.");
      }

      // Written as a subtraction so the check cannot overflow for large inputs.
      if (n_tokens < 0 || n_tokens > tokens.Length - startIndex)
      {
          throw new ArgumentOutOfRangeException(nameof(n_tokens), n_tokens, "Token count must not run past the end of the token array.");
      }

      // The array must stay pinned for the duration of the native call because a raw
      // pointer into it is passed; the range checks above keep that pointer in bounds.
      fixed (llama_token* pinnedTokens = tokens)
      {
          return NativeApi.llama_eval_with_pointer(ctx, pinnedTokens + startIndex, n_tokens, n_past, n_threads);
      }
  }
  /// <summary>
  /// Converts a token to text by asking <c>NativeApi.llama_token_to_str</c> for its native string
  /// and decoding that string with <paramref name="encoding"/>.
  /// </summary>
  /// <param name="token">Token to convert.</param>
  /// <param name="ctx">Context handle passed to the native lookup.</param>
  /// <param name="encoding">Encoding used by <c>PtrToString</c> to decode the native string.</param>
  /// <returns>The decoded text for <paramref name="token"/>.</returns>
  public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding)
  {
      var nativeText = NativeApi.llama_token_to_str(ctx, token);
      return PtrToString(nativeText, encoding);
  }
  /// <summary>
  /// Decodes a null-terminated native string into a managed string using <paramref name="encoding"/>.
  /// </summary>
  /// <remarks>
  /// The same logic runs on every target framework. UTF-16 little-endian text (code page 1200)
  /// is read with <see cref="Marshal.PtrToStringUni(IntPtr)"/>, which stops at a two-byte
  /// terminator. Every other encoding is assumed to use a single zero byte as terminator: the
  /// bytes before it are decoded with <paramref name="encoding"/> itself, so the requested
  /// encoding is always honoured rather than replaced by a platform default.
  /// </remarks>
  /// <param name="ptr">Pointer to the first byte of the native string.</param>
  /// <param name="encoding">Encoding of the native string.</param>
  /// <returns>
  /// The decoded string, or null when <paramref name="ptr"/> is <see cref="IntPtr.Zero"/>
  /// (the same result the <c>Marshal.PtrToString*</c> helpers give for a null pointer).
  /// </returns>
  /// <exception cref="ArgumentNullException"><paramref name="encoding"/> is null.</exception>
  public static unsafe string PtrToString(IntPtr ptr, Encoding encoding)
  {
      const int Utf16LittleEndianCodePage = 1200;

      if (encoding is null)
      {
          throw new ArgumentNullException(nameof(encoding));
      }

      // Never dereference a null pointer; report "no string" instead.
      if (ptr == IntPtr.Zero)
      {
          return null!;
      }

      // Compare by code page rather than by reference so that custom instances such as
      // new UnicodeEncoding(false, false) are recognised as well as Encoding.Unicode.
      if (encoding.CodePage == Utf16LittleEndianCodePage)
      {
          return Marshal.PtrToStringUni(ptr)!;
      }

      // Find the terminating zero byte, then decode in place without an intermediate copy.
      var firstByte = (byte*)ptr.ToPointer();
      var byteCount = 0;
      while (firstByte[byteCount] != 0)
      {
          byteCount++;
      }

      return encoding.GetString(firstByte, byteCount);
  }
  82. }
  83. }