You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaEmbedder.cs 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. using LLama.Native;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using LLama.Exceptions;
  6. using System.Linq;
  7. using LLama.Common;
  namespace LLama
  {
      /// <summary>
      /// The embedder for LLama, which supports getting embeddings from text.
      /// </summary>
      public class LLamaEmbedder : IDisposable
      {
          // Native context used for tokenization and evaluation; released in Dispose().
          private readonly SafeLLamaContextHandle _ctx;

          /// <summary>
          /// Wraps an already-created native context. Ownership of <paramref name="ctx"/> passes to
          /// this embedder: it is disposed when the embedder is disposed.
          /// Warning: the caller must ensure the original model was created with params.embedding = true.
          /// </summary>
          /// <param name="ctx">The native context to evaluate text with.</param>
          internal LLamaEmbedder(SafeLLamaContextHandle ctx)
          {
              _ctx = ctx;
          }

          /// <summary>
          /// Creates a new native context from <paramref name="params"/> with embedding mode enabled.
          /// </summary>
          /// <param name="params">
          /// Model parameters. Note: <c>EmbeddingMode</c> is set to <c>true</c> on this instance,
          /// so the caller's object is modified.
          /// </param>
          /// <exception cref="ArgumentNullException"><paramref name="params"/> is null.</exception>
          public LLamaEmbedder(ModelParams @params)
          {
              if (@params == null)
              {
                  throw new ArgumentNullException(nameof(@params));
              }

              // Embeddings can only be read from a context created in embedding mode.
              @params.EmbeddingMode = true;
              _ctx = Utils.InitLLamaContextFromModelParams(@params);
          }

          /// <summary>
          /// Get the embeddings of the text.
          /// </summary>
          /// <param name="text">The text to embed.</param>
          /// <param name="threads">
          /// Threads used for inference. Any value &lt;= 0 (default -1) selects half the logical
          /// processors, with a minimum of one.
          /// </param>
          /// <param name="addBos">Add bos to the text (a leading space is also prepended).</param>
          /// <param name="encoding">Name of the text encoding used for tokenization.</param>
          /// <returns>
          /// A copy of the embedding vector (length <c>llama_n_embd</c>), or an empty array when the
          /// text produced no tokens or the native context exposed no embeddings.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
          /// <exception cref="RuntimeError">The native evaluation failed.</exception>
          public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
          {
              if (text == null)
              {
                  throw new ArgumentNullException(nameof(text));
              }

              if (threads <= 0)
              {
                  threads = Math.Max(Environment.ProcessorCount / 2, 1);
              }

              if (addBos)
              {
                  // NOTE(review): presumably mirrors llama.cpp's convention of prepending a space
                  // to the prompt before tokenizing — confirm.
                  text = " " + text;
              }

              // Materialize once so the token sequence is not enumerated twice.
              var tokens = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray();
              // TODO(Rinne): deal with log of prompt
              if (tokens.Length == 0)
              {
                  // Nothing is evaluated in this case, so the native embedding buffer may still hold
                  // values from an earlier evaluation; return an empty result rather than stale data.
                  return Array.Empty<float>();
              }

              // Every call evaluates from the start of the context (no past tokens).
              const int nPast = 0;
              if (NativeApi.llama_eval(_ctx, tokens, tokens.Length, nPast, threads) != 0)
              {
                  throw new RuntimeError("Failed to eval.");
              }

              int nEmbd = NativeApi.llama_n_embd(_ctx);
              var embeddings = NativeApi.llama_get_embeddings(_ctx);
              if (embeddings == null)
              {
                  return Array.Empty<float>();
              }

              // Copy out of native memory so the result stays valid independently of the context.
              return new ReadOnlySpan<float>(embeddings, nEmbd).ToArray();
          }

          /// <summary>
          /// Releases the native context owned by this embedder.
          /// </summary>
          public void Dispose()
          {
              _ctx.Dispose();
          }
      }
  }

C#/.NET 上易用的 LLM 高性能推理框架，支持 LLaMA 和 LLaVA 系列模型。