You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaEmbedder.cs 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. using LLama.Native;
  2. using System;
  3. using LLama.Exceptions;
  4. using LLama.Abstractions;
  5. using Microsoft.Extensions.Logging;
  6. namespace LLama
  7. {
  8. /// <summary>
  9. /// The embedder for LLama, which supports getting embeddings from text.
  10. /// </summary>
  11. public sealed class LLamaEmbedder
  12. : IDisposable
  13. {
  14. /// <summary>
  15. /// Dimension of embedding vectors
  16. /// </summary>
  17. public int EmbeddingSize => Context.EmbeddingSize;
  18. /// <summary>
  19. /// LLama Context
  20. /// </summary>
  21. public LLamaContext Context { get; }
  22. /// <summary>
  23. /// Create a new embedder, using the given LLamaWeights
  24. /// </summary>
  25. /// <param name="weights"></param>
  26. /// <param name="params"></param>
  27. /// <param name="logger"></param>
  28. public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
  29. {
  30. @params.EmbeddingMode = true;
  31. Context = weights.CreateContext(@params, logger);
  32. }
  33. /// <summary>
  34. /// Get the embeddings of the text.
  35. /// </summary>
  36. /// <param name="text"></param>
  37. /// <param name="threads">unused</param>
  38. /// <param name="addBos">Add bos to the text.</param>
  39. /// <param name="encoding">unused</param>
  40. /// <returns></returns>
  41. /// <exception cref="RuntimeError"></exception>
  42. [Obsolete("'threads' and 'encoding' parameters are no longer used")]
  43. // ReSharper disable once MethodOverloadWithOptionalParameter
  44. public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
  45. {
  46. return GetEmbeddings(text, addBos);
  47. }
  48. /// <summary>
  49. /// Get the embeddings of the text.
  50. /// </summary>
  51. /// <param name="text"></param>
  52. /// <returns></returns>
  53. /// <exception cref="RuntimeError"></exception>
  54. public float[] GetEmbeddings(string text)
  55. {
  56. return GetEmbeddings(text, true);
  57. }
  58. /// <summary>
  59. /// Get the embeddings of the text.
  60. /// </summary>
  61. /// <param name="text"></param>
  62. /// <param name="addBos">Add bos to the text.</param>
  63. /// <returns></returns>
  64. /// <exception cref="RuntimeError"></exception>
  65. public float[] GetEmbeddings(string text, bool addBos)
  66. {
  67. var embed_inp_array = Context.Tokenize(text, addBos);
  68. // TODO(Rinne): deal with log of prompt
  69. if (embed_inp_array.Length > 0)
  70. Context.Eval(embed_inp_array, 0);
  71. var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
  72. if (embeddings == null)
  73. return Array.Empty<float>();
  74. return embeddings.ToArray();
  75. }
  76. /// <summary>
  77. ///
  78. /// </summary>
  79. public void Dispose()
  80. {
  81. Context.Dispose();
  82. }
  83. }
  84. }