You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

LLamaEmbedder.cs 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. using LLama.Native;
  2. using System;
  3. using LLama.Exceptions;
  4. using LLama.Abstractions;
  namespace LLama
  {
      /// <summary>
      /// The embedder for LLama, which supports getting embeddings from text.
      /// </summary>
      /// <remarks>
      /// Each instance owns one <see cref="LLamaContext"/> created with
      /// <c>EmbeddingMode</c> enabled, and disposes it in <see cref="Dispose"/>.
      /// </remarks>
      public sealed class LLamaEmbedder
          : IDisposable
      {
          // Context used for tokenization and evaluation; created in the constructor and owned by this instance.
          private readonly LLamaContext _ctx;

          /// <summary>
          /// Dimension of embedding vectors
          /// </summary>
          public int EmbeddingSize => _ctx.EmbeddingSize;

          /// <summary>
          /// Create a new embedder (loading temporary weights)
          /// </summary>
          /// <param name="allParams">Combined parameters, used both as the model parameters and as the context parameters.</param>
          /// <exception cref="ArgumentNullException"><paramref name="allParams"/> is null.</exception>
          [Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
          public LLamaEmbedder(ILLamaParams allParams)
              : this(allParams, allParams)
          {
          }

          /// <summary>
          /// Create a new embedder (loading temporary weights)
          /// </summary>
          /// <param name="modelParams">Parameters passed to <c>LLamaWeights.LoadFromFile</c>.</param>
          /// <param name="contextParams">Parameters used to create the context. Note: <c>EmbeddingMode</c> is set to true on this object.</param>
          /// <exception cref="ArgumentNullException"><paramref name="modelParams"/> or <paramref name="contextParams"/> is null.</exception>
          [Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
          public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams)
          {
              // Validate both arguments before touching the disk, so a bad call does not load a model just to fail.
              if (modelParams is null)
              {
                  throw new ArgumentNullException(nameof(modelParams));
              }

              if (contextParams is null)
              {
                  throw new ArgumentNullException(nameof(contextParams));
              }

              // The weights are disposed when this constructor exits.
              // NOTE(review): presumably the created context keeps the native model alive on its own - confirm.
              using var weights = LLamaWeights.LoadFromFile(modelParams);

              contextParams.EmbeddingMode = true;
              _ctx = weights.CreateContext(contextParams);
          }

          /// <summary>
          /// Create a new embedder, using the given LLamaWeights
          /// </summary>
          /// <param name="weights">Already loaded weights from which the context is created.</param>
          /// <param name="params">Parameters used to create the context. Note: <c>EmbeddingMode</c> is set to true on this object.</param>
          /// <exception cref="ArgumentNullException"><paramref name="weights"/> or <paramref name="params"/> is null.</exception>
          public LLamaEmbedder(LLamaWeights weights, IContextParams @params)
          {
              if (weights is null)
              {
                  throw new ArgumentNullException(nameof(weights));
              }

              if (@params is null)
              {
                  throw new ArgumentNullException(nameof(@params));
              }

              @params.EmbeddingMode = true;
              _ctx = weights.CreateContext(@params);
          }

          /// <summary>
          /// Get the embeddings of the text.
          /// </summary>
          /// <param name="text">Text to embed.</param>
          /// <param name="threads">unused</param>
          /// <param name="addBos">Add bos to the text.</param>
          /// <param name="encoding">unused</param>
          /// <returns>Same result as <see cref="GetEmbeddings(string, bool)"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
          /// <exception cref="RuntimeError"></exception>
          [Obsolete("'threads' and 'encoding' parameters are no longer used")]
          // ReSharper disable once MethodOverloadWithOptionalParameter
          public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
          {
              return GetEmbeddings(text, addBos);
          }

          /// <summary>
          /// Get the embeddings of the text. A bos token is always requested.
          /// </summary>
          /// <param name="text">Text to embed.</param>
          /// <returns>Same result as <see cref="GetEmbeddings(string, bool)"/> with <c>addBos</c> set to true.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
          /// <exception cref="RuntimeError"></exception>
          public float[] GetEmbeddings(string text)
          {
              return GetEmbeddings(text, true);
          }

          /// <summary>
          /// Get the embeddings of the text.
          /// </summary>
          /// <param name="text">Text to embed.</param>
          /// <param name="addBos">Add bos to the text.</param>
          /// <returns>
          /// A managed copy of <see cref="EmbeddingSize"/> floats read from the native embeddings buffer,
          /// or an empty array when the native call returns a null pointer.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
          /// <exception cref="RuntimeError"></exception>
          public float[] GetEmbeddings(string text, bool addBos)
          {
              if (text is null)
              {
                  throw new ArgumentNullException(nameof(text));
              }

              var tokens = _ctx.Tokenize(text, addBos);

              // TODO(Rinne): deal with log of prompt

              // NOTE(review): when tokenization yields no tokens, evaluation is skipped and the values
              // read below are whatever the native context currently holds - confirm this is intended.
              if (tokens.Length > 0)
              {
                  // NOTE(review): the second argument (0) is presumably the number of previously evaluated tokens - confirm.
                  _ctx.Eval(tokens, 0);
              }

              unsafe
              {
                  var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle);
                  if (embeddings == null)
                  {
                      return Array.Empty<float>();
                  }

                  // Copy out of native memory so the result does not alias the context's buffer.
                  return new Span<float>(embeddings, EmbeddingSize).ToArray();
              }
          }

          /// <summary>
          /// Disposes the underlying <see cref="LLamaContext"/> owned by this embedder.
          /// </summary>
          public void Dispose()
          {
              _ctx.Dispose();
          }
      }
  }