You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaEmbedder.cs 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. using LLama.Native;
  2. using System;
  3. using LLama.Exceptions;
  4. using LLama.Abstractions;
  5. using Microsoft.Extensions.Logging;
  6. namespace LLama
  7. {
  /// <summary>
  /// The embedder for LLama, which supports getting embeddings from text.
  /// </summary>
  /// <remarks>
  /// Owns a single <see cref="LLamaContext"/> that is created in the constructor
  /// and disposed by <see cref="Dispose"/>.
  /// </remarks>
  public sealed class LLamaEmbedder
      : IDisposable
  {
      private readonly LLamaContext _ctx;

      /// <summary>
      /// Dimension of embedding vectors
      /// </summary>
      public int EmbeddingSize => _ctx.EmbeddingSize;

      /// <summary>
      /// Create a new embedder (loading temporary weights)
      /// </summary>
      /// <param name="allParams">Parameters used both to load the model weights and to create the context.</param>
      /// <param name="logger">Optional logger passed on to context creation.</param>
      /// <exception cref="ArgumentNullException"><paramref name="allParams"/> is null.</exception>
      [Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
      public LLamaEmbedder(ILLamaParams allParams, ILogger? logger = null)
          : this(allParams, allParams, logger)
      {
      }

      /// <summary>
      /// Create a new embedder (loading temporary weights)
      /// </summary>
      /// <remarks>
      /// The weights are loaded from file for the duration of this constructor only and are
      /// disposed before it returns. <c>EmbeddingMode</c> is set to <c>true</c> on the
      /// supplied <paramref name="contextParams"/> instance.
      /// </remarks>
      /// <param name="modelParams">Parameters used to load the model weights from file.</param>
      /// <param name="contextParams">Parameters used to create the context; this instance is modified.</param>
      /// <param name="logger">Optional logger passed on to context creation.</param>
      /// <exception cref="ArgumentNullException">
      /// <paramref name="modelParams"/> or <paramref name="contextParams"/> is null.
      /// </exception>
      [Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
      public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams, ILogger? logger = null)
      {
          // Validate before touching the disk so a bad argument does not cost a weight load.
          if (modelParams is null)
          {
              throw new ArgumentNullException(nameof(modelParams));
          }

          if (contextParams is null)
          {
              throw new ArgumentNullException(nameof(contextParams));
          }

          using var weights = LLamaWeights.LoadFromFile(modelParams);

          contextParams.EmbeddingMode = true;
          _ctx = weights.CreateContext(contextParams, logger);
      }

      /// <summary>
      /// Create a new embedder, using the given LLamaWeights
      /// </summary>
      /// <remarks>
      /// <c>EmbeddingMode</c> is set to <c>true</c> on the supplied <paramref name="params"/> instance.
      /// </remarks>
      /// <param name="weights">Already loaded weights from which the context is created.</param>
      /// <param name="params">Parameters used to create the context; this instance is modified.</param>
      /// <param name="logger">Optional logger passed on to context creation.</param>
      /// <exception cref="ArgumentNullException">
      /// <paramref name="weights"/> or <paramref name="params"/> is null.
      /// </exception>
      public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
      {
          if (weights is null)
          {
              throw new ArgumentNullException(nameof(weights));
          }

          if (@params is null)
          {
              throw new ArgumentNullException(nameof(@params));
          }

          @params.EmbeddingMode = true;
          _ctx = weights.CreateContext(@params, logger);
      }

      /// <summary>
      /// Get the embeddings of the text.
      /// </summary>
      /// <param name="text">The text to embed.</param>
      /// <param name="threads">unused</param>
      /// <param name="addBos">Add bos to the text.</param>
      /// <param name="encoding">unused</param>
      /// <returns>The result of <see cref="GetEmbeddings(string, bool)"/> for the same text and bos setting.</returns>
      /// <exception cref="RuntimeError"></exception>
      /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
      [Obsolete("'threads' and 'encoding' parameters are no longer used")]
      // ReSharper disable once MethodOverloadWithOptionalParameter
      public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
      {
          return GetEmbeddings(text, addBos);
      }

      /// <summary>
      /// Get the embeddings of the text, adding bos to it.
      /// </summary>
      /// <param name="text">The text to embed.</param>
      /// <returns>The result of <see cref="GetEmbeddings(string, bool)"/> with <c>addBos</c> set to <c>true</c>.</returns>
      /// <exception cref="RuntimeError"></exception>
      /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
      public float[] GetEmbeddings(string text)
      {
          return GetEmbeddings(text, true);
      }

      /// <summary>
      /// Get the embeddings of the text.
      /// </summary>
      /// <param name="text">The text to embed.</param>
      /// <param name="addBos">Add bos to the text.</param>
      /// <returns>
      /// A managed copy of <see cref="EmbeddingSize"/> floats read from the native embeddings buffer,
      /// or an empty array when the native call returns a null pointer.
      /// </returns>
      /// <exception cref="RuntimeError"></exception>
      /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
      public float[] GetEmbeddings(string text, bool addBos)
      {
          if (text is null)
          {
              throw new ArgumentNullException(nameof(text));
          }

          var tokens = _ctx.Tokenize(text, addBos);

          // TODO(Rinne): deal with log of prompt

          // NOTE(review): when tokenization yields no tokens nothing is evaluated, yet the
          // native buffer is still read below, so the result reflects whatever the context
          // currently holds. Confirm this is intended.
          if (tokens.Length > 0)
          {
              // NOTE(review): the second argument is presumably the number of past tokens,
              // i.e. every call evaluates from the start of the context. Verify against LLamaContext.Eval.
              _ctx.Eval(tokens, 0);
          }

          unsafe
          {
              var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle);
              if (embeddings == null)
              {
                  return Array.Empty<float>();
              }

              // Copy out of native memory so the caller gets an independent managed array.
              return new Span<float>(embeddings, EmbeddingSize).ToArray();
          }
      }

      /// <summary>
      /// Disposes the context owned by this embedder.
      /// </summary>
      public void Dispose()
      {
          _ctx.Dispose();
      }
  }
  107. }