You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

MirostatSamplingPipeline.cs 1.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. using System;
  2. using LLama.Native;
  3. namespace LLama.Sampling;
  /// <summary>
  /// A sampling pipeline which uses mirostat (v1) to select tokens.
  /// </summary>
  /// <remarks>
  /// The type name contains a historical typo ("Mirostate"); it is kept as-is so existing callers keep compiling.
  /// </remarks>
  public class MirostateSamplingPipeline
      : BaseSamplingPipeline
  {
      private const int DefaultM = 100;
      private const float DefaultTau = 5;

      // mu starts at 2 * tau; the same rule is applied by the Tau setter and by Reset().
      private float _mu = DefaultTau * 2;
      private float _tau = DefaultTau;
      private int _m = DefaultM;

      /// <summary>
      /// Currently learned mu value. Starts at <c>2 * Tau</c> and is passed by reference to
      /// <c>SampleTokenMirostat</c> on every sampling call, which may update it.
      /// </summary>
      public float Mu => _mu;

      /// <summary>
      /// Target entropy (tau). Setting this also resets <see cref="Mu"/> to <c>2 * value</c>,
      /// discarding any previously learned mu.
      /// </summary>
      public float Tau
      {
          get => _tau;
          set
          {
              _tau = value;
              _mu = value * 2;
          }
      }

      /// <summary>
      /// Learning rate (eta), passed to <c>SampleTokenMirostat</c> on every sampling call.
      /// </summary>
      public float Eta { get; set; } = 0.1f;

      /// <summary>
      /// The mirostat <c>m</c> parameter: the number of most probable tokens used to estimate s_hat.
      /// Defaults to 100, which is the value this pipeline previously hard-coded.
      /// </summary>
      /// <exception cref="ArgumentOutOfRangeException">Thrown when set to a value less than 2.</exception>
      public int M
      {
          get => _m;
          set
          {
              // s_hat is fitted over adjacent probability pairs among the first m tokens,
              // so fewer than two tokens leaves the estimate undefined (0/0).
              if (value < 2)
              {
                  throw new ArgumentOutOfRangeException(nameof(value), value, "Mirostat m must be at least 2");
              }

              _m = value;
          }
      }

      /// <inheritdoc />
      protected override ReadOnlySpan<float> ProcessLogits(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
      {
          // Mirostat does all of its work on the candidate array, so the logits pass through untouched.
          return logits;
      }

      /// <inheritdoc />
      protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
      {
          return candidates.SampleTokenMirostat(ctx, Tau, Eta, M, ref _mu);
      }

      /// <inheritdoc />
      public override void Reset()
      {
          base.Reset();
          _mu = Tau * 2;
      }

      /// <inheritdoc />
      public override ISamplingPipeline Clone()
      {
          // Fields are copied directly: going through the Tau setter would overwrite the learned mu.
          return new MirostateSamplingPipeline
          {
              Grammar = Grammar?.Clone(),
              _mu = _mu,
              _tau = _tau,
              _m = _m,
              Eta = Eta
          };
      }
  }