You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaEmbedder.cs 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. using LLama.Native;
  2. using System;
  3. using LLama.Exceptions;
  4. using LLama.Abstractions;
  namespace LLama
  {
      /// <summary>
      /// The embedder for LLama, which supports getting embeddings from text.
      /// </summary>
      public sealed class LLamaEmbedder
          : IDisposable
      {
          // Context used for tokenization and evaluation; owned by this embedder and released in Dispose().
          private readonly LLamaContext _ctx;

          /// <summary>
          /// Dimension of embedding vectors
          /// </summary>
          public int EmbeddingSize => _ctx.EmbeddingSize;

          /// <summary>
          /// Create an embedder by loading model weights with <see cref="LLamaWeights.LoadFromFile"/>
          /// and creating a context from them.
          /// </summary>
          /// <param name="params">
          /// Parameters used to load the weights and create the context.
          /// Note: <c>EmbeddingMode</c> is set to <c>true</c> on this object (the caller's instance is mutated).
          /// </param>
          public LLamaEmbedder(IModelParams @params)
          {
              @params.EmbeddingMode = true;

              // The weights object is disposed when the constructor exits.
              // NOTE(review): presumably the created context keeps the native model alive on its own - confirm.
              using var weights = LLamaWeights.LoadFromFile(@params);
              _ctx = weights.CreateContext(@params);
          }

          /// <summary>
          /// Create an embedder from weights that have already been loaded.
          /// </summary>
          /// <param name="weights">Loaded weights used to create the context. They are not disposed by this class.</param>
          /// <param name="params">
          /// Parameters used to create the context.
          /// Note: <c>EmbeddingMode</c> is set to <c>true</c> on this object (the caller's instance is mutated),
          /// matching the behaviour of the constructor that loads the weights itself.
          /// </param>
          public LLamaEmbedder(LLamaWeights weights, IModelParams @params)
          {
              // Keep both constructors consistent: the context must be created in embedding mode,
              // otherwise GetEmbeddings has no embedding output to read.
              @params.EmbeddingMode = true;
              _ctx = weights.CreateContext(@params);
          }

          /// <summary>
          /// Get the embeddings of the text.
          /// </summary>
          /// <param name="text">The text to embed.</param>
          /// <param name="threads">unused</param>
          /// <param name="addBos">Add bos to the text.</param>
          /// <param name="encoding">unused</param>
          /// <returns>The embedding values, or an empty array if no embeddings are available.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
          /// <exception cref="RuntimeError">Presumably raised by the underlying evaluation on failure (not visible in this file).</exception>
          [Obsolete("'threads' and 'encoding' parameters are no longer used")]
          // ReSharper disable once MethodOverloadWithOptionalParameter
          public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
          {
              return GetEmbeddings(text, addBos);
          }

          /// <summary>
          /// Get the embeddings of the text, adding bos to the text.
          /// </summary>
          /// <param name="text">The text to embed.</param>
          /// <returns>The embedding values, or an empty array if no embeddings are available.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
          /// <exception cref="RuntimeError">Presumably raised by the underlying evaluation on failure (not visible in this file).</exception>
          public float[] GetEmbeddings(string text)
          {
              return GetEmbeddings(text, true);
          }

          /// <summary>
          /// Get the embeddings of the text.
          /// </summary>
          /// <param name="text">The text to embed.</param>
          /// <param name="addBos">Add bos to the text. When true, a single space is also prepended to the text before tokenization.</param>
          /// <returns>
          /// A new array of <see cref="EmbeddingSize"/> values copied from the native embedding buffer,
          /// or an empty array if the native call returns a null pointer.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
          /// <exception cref="RuntimeError">Presumably raised by the underlying evaluation on failure (not visible in this file).</exception>
          public float[] GetEmbeddings(string text, bool addBos)
          {
              // Fail with a descriptive exception instead of a NullReferenceException from text.Insert below.
              if (text == null)
              {
                  throw new ArgumentNullException(nameof(text));
              }

              if (addBos)
              {
                  text = text.Insert(0, " ");
              }

              var tokens = _ctx.Tokenize(text, addBos);

              // TODO(Rinne): deal with log of prompt
              // Evaluation is skipped entirely when tokenization produced nothing.
              if (tokens.Length > 0)
              {
                  _ctx.Eval(tokens, 0);
              }

              unsafe
              {
                  var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle);
                  if (embeddings == null)
                  {
                      return Array.Empty<float>();
                  }

                  // Copy out of native memory so the returned array does not alias the context's buffer.
                  return new Span<float>(embeddings, EmbeddingSize).ToArray();
              }
          }

          /// <summary>
          /// Dispose the context owned by this embedder.
          /// </summary>
          public void Dispose()
          {
              _ctx.Dispose();
          }
      }
  }