You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaEmbedder.cs 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. using LLama.Native;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using LLama.Exceptions;
  6. using System.Linq;
  7. using LLama.Abstractions;
  8. namespace LLama
  9. {
  /// <summary>
  /// The embedder for LLama, which supports getting embeddings from text.
  /// </summary>
  public class LLamaEmbedder : IDisposable
  {
      // Context handle used for tokenization and evaluation; released in Dispose().
      private readonly SafeLLamaContextHandle _ctx;

      /// <summary>
      /// Creates an embedder on top of an existing context handle.
      /// Warning: must ensure the original model has params.embedding = true;
      /// </summary>
      /// <param name="ctx">The context handle to use. It is disposed when this embedder is disposed.</param>
      /// <exception cref="ArgumentNullException"><paramref name="ctx"/> is null.</exception>
      internal LLamaEmbedder(SafeLLamaContextHandle ctx)
      {
          _ctx = ctx ?? throw new ArgumentNullException(nameof(ctx));
      }

      /// <summary>
      /// Creates an embedder by initializing a new context from the given model parameters.
      /// </summary>
      /// <param name="params">
      /// The model parameters. Note that this constructor sets <c>EmbeddingMode</c> to true
      /// on the passed instance before the context is created.
      /// </param>
      /// <exception cref="ArgumentNullException"><paramref name="params"/> is null.</exception>
      public LLamaEmbedder(IModelParams @params)
      {
          if (@params is null)
          {
              throw new ArgumentNullException(nameof(@params));
          }

          // The caller's params object is mutated here so the context is created in embedding mode.
          @params.EmbeddingMode = true;
          _ctx = Utils.InitLLamaContextFromModelParams(@params);
      }

      /// <summary>
      /// Get the embeddings of the text.
      /// </summary>
      /// <param name="text">The text to embed.</param>
      /// <param name="threads">
      /// Threads used for inference. Any non-positive value (including the default -1)
      /// selects half of the processor count, with a minimum of one thread.
      /// </param>
      /// <param name="addBos">
      /// Add bos to the text. When true, a single space is also prepended to the text before tokenization.
      /// </param>
      /// <param name="encoding">Name of the text encoding, resolved through <see cref="Encoding.GetEncoding(string)"/>.</param>
      /// <returns>
      /// A newly allocated array holding a copy of the embedding values read from the context,
      /// or an empty array when no embeddings pointer is available.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
      /// <exception cref="RuntimeError">The evaluation of the tokenized text failed.</exception>
      public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
      {
          if (text is null)
          {
              throw new ArgumentNullException(nameof(text));
          }

          // A zero or negative thread count is never meaningful for evaluation, so treat it as "use the default".
          if (threads <= 0)
          {
              threads = Math.Max(Environment.ProcessorCount / 2, 1);
          }

          // Every call evaluates the text from the beginning of the context.
          const int pastTokenCount = 0;

          if (addBos)
          {
              text = text.Insert(0, " ");
          }

          var tokens = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray();

          // TODO(Rinne): deal with log of prompt
          // NOTE(review): when tokenization yields no tokens the evaluation is skipped and the
          // embeddings currently held by the context are returned as-is - confirm this is intended.
          if (tokens.Length > 0
              && NativeApi.llama_eval(_ctx, tokens, tokens.Length, pastTokenCount, threads) != 0)
          {
              throw new RuntimeError("Failed to eval.");
          }

          int embeddingSize = NativeApi.llama_n_embd(_ctx);
          var embeddings = NativeApi.llama_get_embeddings(_ctx);
          if (embeddings == null)
          {
              return Array.Empty<float>();
          }

          // Copy out of the native buffer so the result does not alias memory owned by the context.
          return new Span<float>(embeddings, embeddingSize).ToArray();
      }

      /// <summary>
      /// Disposes the underlying context handle.
      /// </summary>
      public void Dispose()
      {
          _ctx.Dispose();
      }
  }
  81. }