You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaEmbedder.cs 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. using LLama.Native;
  2. using System;
  3. using System.Text;
  4. using LLama.Exceptions;
  5. using LLama.Abstractions;
  namespace LLama
  {
      /// <summary>
      /// The embedder for LLama, which supports getting embeddings from text.
      /// </summary>
      public class LLamaEmbedder
          : IDisposable
      {
          // Native context handle; released in Dispose().
          private readonly SafeLLamaContextHandle _ctx;

          /// <summary>
          /// Dimension of embedding vectors
          /// </summary>
          public int EmbeddingSize => _ctx.EmbeddingSize;

          /// <summary>
          /// Wraps an existing context handle.
          /// Warning: must ensure the original model has params.embedding = true;
          /// </summary>
          /// <param name="ctx">The context handle this embedder evaluates against and disposes.</param>
          internal LLamaEmbedder(SafeLLamaContextHandle ctx)
          {
              _ctx = ctx;
          }

          /// <summary>
          /// Creates an embedder by initializing a context from the given model parameters.
          /// Note that <c>EmbeddingMode</c> is set to <c>true</c> on the passed-in
          /// <paramref name="params"/> object before the context is created.
          /// </summary>
          /// <param name="params">Model parameters used to create the context.</param>
          /// <exception cref="ArgumentNullException"><paramref name="params"/> is null.</exception>
          public LLamaEmbedder(IModelParams @params)
          {
              if (@params is null)
              {
                  throw new ArgumentNullException(nameof(@params));
              }

              @params.EmbeddingMode = true;
              _ctx = Utils.InitLLamaContextFromModelParams(@params);
          }

          /// <summary>
          /// Get the embeddings of the text.
          /// </summary>
          /// <param name="text">The text to embed.</param>
          /// <param name="threads">
          /// Threads used for inference. The default of -1 selects half of the
          /// available processors (at least one).
          /// </param>
          /// <param name="addBos">
          /// Add bos to the text. When true, a single space is also prepended to
          /// <paramref name="text"/> before tokenization.
          /// </param>
          /// <param name="encoding">Name of the text encoding handed to the tokenizer.</param>
          /// <returns>
          /// A newly allocated array holding the embedding values, or an empty array when the
          /// native call returns no embeddings.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
          /// <exception cref="RuntimeError">The native evaluation reported a failure.</exception>
          public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
          {
              if (text is null)
              {
                  throw new ArgumentNullException(nameof(text));
              }

              if (threads == -1)
              {
                  threads = Math.Max(Environment.ProcessorCount / 2, 1);
              }

              if (addBos)
              {
                  text = " " + text;
              }

              var tokens = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding));

              // TODO(Rinne): deal with log of prompt
              // NOTE(review): with zero tokens the evaluation is skipped, yet the context's
              // current embeddings are still read below — confirm this is intended.
              if (tokens.Length > 0 && NativeApi.llama_eval(_ctx, tokens, tokens.Length, 0, threads) != 0)
              {
                  throw new RuntimeError("Failed to eval.");
              }

              int embeddingLength = NativeApi.llama_n_embd(_ctx);
              var nativeEmbeddings = NativeApi.llama_get_embeddings(_ctx);
              if (nativeEmbeddings == null)
              {
                  return Array.Empty<float>();
              }

              // Copy out of native memory so the result stays valid independently of the context.
              return new Span<float>(nativeEmbeddings, embeddingLength).ToArray();
          }

          /// <summary>
          /// Disposes the underlying context handle.
          /// </summary>
          public void Dispose()
          {
              _ctx.Dispose();
          }
      }
  }