You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MirostatSamplingPipeline.cs 1.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. using System;
  2. using LLama.Native;
  3. namespace LLama.Sampling;
  /// <summary>
  /// Sampling pipeline that selects tokens using the mirostat (v1) sampler.
  /// </summary>
  public class MirostateSamplingPipeline
      : BaseSamplingPipeline
  {
      // Passed as the "m" argument to SampleTokenMirostat.
      private const int MirostatM = 100;

      // Tau in effect until a caller assigns a different one.
      private const float DefaultTau = 5;

      private float _tau = DefaultTau;

      // Always (re)initialised to twice the current tau; updated by reference while sampling.
      private float _mu = DefaultTau * 2;

      /// <summary>
      /// The mu value learned so far
      /// </summary>
      public float Mu
      {
          get { return _mu; }
      }

      /// <summary>
      /// Target entropy. Assigning a new value also resets <see cref="Mu"/> to twice that value.
      /// </summary>
      public float Tau
      {
          get { return _tau; }
          set
          {
              _mu = value * 2;
              _tau = value;
          }
      }

      /// <summary>
      /// Learning rate
      /// </summary>
      public float Eta { get; set; } = 0.1f;

      /// <inheritdoc />
      protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
      {
          // Intentionally empty: this pipeline does no work at the logits stage.
      }

      /// <inheritdoc />
      protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
          => candidates.SampleTokenMirostat(ctx, Tau, Eta, MirostatM, ref _mu);

      /// <inheritdoc />
      public override void Reset()
      {
          base.Reset();

          // Return mu to its starting point for the current tau.
          _mu = _tau * 2;
      }

      /// <inheritdoc />
      public override ISamplingPipeline Clone()
      {
          var copy = new MirostateSamplingPipeline();
          copy.Grammar = Grammar?.Clone();

          // Write the backing fields directly: going through the Tau setter
          // would overwrite the copied mu with tau * 2.
          copy._mu = _mu;
          copy._tau = _tau;
          copy.Eta = Eta;

          return copy;
      }
  }