You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaEmbedder.cs 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. using LLama.Native;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using LLama.Exceptions;
  6. namespace LLama.OldVersion
  7. {
  /// <summary>
  /// Produces embedding vectors for text by evaluating it on a native llama context.
  /// </summary>
  public class LLamaEmbedder : IDisposable
  {
      private readonly SafeLLamaContextHandle _ctx;

      /// <summary>
      /// Wraps an existing context handle.
      /// Warning: must ensure the original model has params.embedding = true;
      /// </summary>
      /// <param name="ctx">The context handle used for tokenization and evaluation.</param>
      internal LLamaEmbedder(SafeLLamaContextHandle ctx)
      {
          _ctx = ctx;
      }

      /// <summary>
      /// Creates a context from the given parameters, with <c>embedding</c> forced to <c>true</c>.
      /// </summary>
      /// <param name="params">Parameters forwarded to <c>Utils.llama_init_from_gpt_params</c>.</param>
      public LLamaEmbedder(LLamaParams @params)
      {
          @params.embedding = true;
          _ctx = Utils.llama_init_from_gpt_params(ref @params);
      }

      /// <summary>
      /// Tokenizes <paramref name="text"/>, evaluates the tokens and returns a copy of the
      /// embeddings reported by the native context.
      /// </summary>
      /// <param name="text">The text to embed. Must not be null.</param>
      /// <param name="n_thread">
      /// Thread count passed to the native evaluation. The value -1 selects half of
      /// <see cref="Environment.ProcessorCount"/>, with a minimum of 1.
      /// </param>
      /// <param name="add_bos">
      /// Forwarded to the tokenizer; when true, a single space is also prepended to the text.
      /// </param>
      /// <param name="encoding">Encoding name forwarded to the tokenizer.</param>
      /// <returns>
      /// A new array holding <c>llama_n_embd</c> floats, or an empty array when the native
      /// layer returns no embeddings.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
      /// <exception cref="RuntimeError">The native evaluation returned a non-zero status.</exception>
      public unsafe float[] GetEmbeddings(string text, int n_thread = -1, bool add_bos = true, string encoding = "UTF-8")
      {
          // Fail early with a descriptive exception instead of a NullReferenceException
          // (or handing null to the tokenizer when add_bos is false).
          if (text is null)
          {
              throw new ArgumentNullException(nameof(text));
          }

          if (n_thread == -1)
          {
              n_thread = Math.Max(Environment.ProcessorCount / 2, 1);
          }

          // Every call evaluates from the start of the context.
          const int n_past = 0;

          if (add_bos)
          {
              text = text.Insert(0, " ");
          }

          var embed_inp = Utils.llama_tokenize(_ctx, text, add_bos, encoding);

          // TODO(Rinne): deal with log of prompt
          if (embed_inp.Count > 0)
          {
              var embed_inp_array = embed_inp.ToArray();
              if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, n_thread) != 0)
              {
                  throw new RuntimeError("Failed to eval.");
              }
          }

          int n_embed = NativeApi.llama_n_embd(_ctx);
          var embeddings = NativeApi.llama_get_embeddings(_ctx);
          if (embeddings == null)
          {
              return Array.Empty<float>();
          }

          // Copy out of the native buffer so the caller gets a managed array of its own.
          return new Span<float>(embeddings, n_embed).ToArray();
      }

      /// <summary>
      /// Disposes the underlying context handle, if there is one.
      /// </summary>
      public void Dispose()
      {
          _ctx?.Dispose();
      }
  }
  61. }