You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaEmbedder.cs 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. using LLama.Native;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using LLama.Exceptions;
  6. using LLama.Abstractions.Params;
  7. using System.Linq;
  8. namespace LLama
  9. {
  /// <summary>
  /// Produces embedding vectors for text by evaluating it on a llama context.
  /// </summary>
  public class LLamaEmbedder : IDisposable
  {
      // Context handle used for every native call; released in Dispose().
      private readonly SafeLLamaContextHandle _ctx;

      /// <summary>
      /// Wraps an existing context handle.
      /// Warning: must ensure the original model has params.embedding = true;
      /// </summary>
      /// <remarks>
      /// The handle is stored as-is and is disposed when this embedder is disposed.
      /// </remarks>
      /// <param name="ctx">The context handle to evaluate on.</param>
      internal LLamaEmbedder(SafeLLamaContextHandle ctx)
      {
          _ctx = ctx;
      }

      /// <summary>
      /// Creates a context from the given model parameters.
      /// </summary>
      /// <remarks>
      /// Sets <c>EmbeddingMode</c> to <c>true</c> on the passed-in <paramref name="params"/>
      /// instance before the context is created, so the caller's object is modified.
      /// </remarks>
      /// <param name="params">Model parameters used to initialize the context.</param>
      public LLamaEmbedder(ModelParams @params)
      {
          @params.EmbeddingMode = true;
          _ctx = Utils.InitLLamaContextFromModelParams(@params);
      }

      /// <summary>
      /// Get the embeddings of the text.
      /// </summary>
      /// <param name="text">The text to embed. Must not be null.</param>
      /// <param name="threads">
      /// Threads used for inference. The default of -1 selects half of
      /// <see cref="Environment.ProcessorCount"/>, with a minimum of 1.
      /// </param>
      /// <param name="addBos">
      /// Add bos to the text. When true, a single space is also prepended to the text
      /// before tokenization.
      /// </param>
      /// <param name="encoding">Name of the text encoding passed to the tokenizer.</param>
      /// <returns>
      /// A newly allocated array holding the embedding values, or an empty array when
      /// the native library returns no embeddings.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
      /// <exception cref="RuntimeError">The native evaluation call reported a failure.</exception>
      public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
      {
          if (text is null)
          {
              throw new ArgumentNullException(nameof(text));
          }

          if (threads == -1)
          {
              threads = Math.Max(Environment.ProcessorCount / 2, 1);
          }

          if (addBos)
          {
              text = text.Insert(0, " ");
          }

          // Materialize the tokens exactly once so the tokenizer result is not
          // enumerated a second time just to check whether it is empty.
          var tokens = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray();

          // TODO(Rinne): deal with log of prompt
          if (tokens.Length > 0)
          {
              // The n_past argument is always 0 for this call.
              const int nPast = 0;
              if (NativeApi.llama_eval(_ctx, tokens, tokens.Length, nPast, threads) != 0)
              {
                  throw new RuntimeError("Failed to eval.");
              }
          }

          int embeddingLength = NativeApi.llama_n_embd(_ctx);
          var embeddings = NativeApi.llama_get_embeddings(_ctx);
          if (embeddings == null)
          {
              return Array.Empty<float>();
          }

          // Copy out of the native buffer so the caller receives a managed array.
          return new Span<float>(embeddings, embeddingLength).ToArray();
      }

      /// <summary>
      /// Disposes the underlying context handle.
      /// </summary>
      public void Dispose()
      {
          _ctx.Dispose();
      }
  }
  72. }

C#/.NET上易用的LLM高性能推理框架,支持LLaMA和LLaVA系列模型。