You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

DefaultSamplingPipeline.cs 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. using System;
  2. using System.Collections.Generic;
  3. using LLama.Extensions;
  4. using LLama.Native;
  5. namespace LLama.Sampling;
  6. /// <summary>
  7. /// An implementation of ISamplePipeline which mimics the default llama.cpp sampling
  8. /// </summary>
  9. public sealed class DefaultSamplingPipeline
  10. : BaseSamplingPipeline
  11. {
  12. /// <summary>
  13. /// Bias values to add to certain logits
  14. /// </summary>
  15. public Dictionary<int, float> LogitBias { get; } = new();
  16. /// <summary>
  17. /// Grammar to constrain valid tokens
  18. /// </summary>
  19. public SafeLLamaGrammarHandle? Grammar { get; set; }
  20. /// <summary>
  21. /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
  22. /// </summary>
  23. public float RepeatPenalty { get; set; } = 1.1f;
  24. /// <summary>
  25. /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
  26. /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
  27. /// so far, decreasing the model's likelihood to repeat the same line verbatim.
  28. /// </summary>
  29. public float AlphaFrequency
  30. {
  31. get => _alphaFreq;
  32. set
  33. {
  34. if (value < -2)
  35. throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
  36. if (value > 2)
  37. throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
  38. _alphaFreq = value;
  39. }
  40. }
  41. private float _alphaFreq = 0.1f;
  42. /// <summary>
  43. /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
  44. /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
  45. /// text so far, increasing the model's likelihood to talk about new topics.
  46. /// </summary>
  47. public float AlphaPresence
  48. {
  49. get => _alphaPresence;
  50. set
  51. {
  52. if (value < -2)
  53. throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
  54. if (value > 2)
  55. throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
  56. _alphaPresence = value;
  57. }
  58. }
  59. private float _alphaPresence = 0.1f;
  60. /// <summary>
  61. /// Temperature to apply (higher temperature is more "creative")
  62. /// </summary>
  63. public float Temperature { get; set; } = 0.75f;
  64. /// <summary>
  65. /// Number of tokens to keep in TopK sampling
  66. /// </summary>
  67. public int TopK { get; set; }
  68. /// <summary>
  69. /// Z value for tail free sampling
  70. /// </summary>
  71. public float TailFreeZ { get; set; }
  72. /// <summary>
  73. /// P value for locally typical sampling
  74. /// </summary>
  75. public float TypicalP { get; set; }
  76. /// <summary>
  77. /// P value for TopP sampling
  78. /// </summary>
  79. public float TopP { get; set; } = 1f;
  80. /// <summary>
  81. /// P value for MinP sampling
  82. /// </summary>
  83. public float MinP { get; set; }
  84. /// <summary>
  85. /// Whether the newline value should be protected from being modified by logit bias and repeat penalty
  86. /// </summary>
  87. public bool PenalizeNewline { get; set; } = false;
  88. private readonly int[] _newlineToken = new int[1];
  89. /// <inheritdoc />
  90. protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
  91. {
  92. if (PenalizeNewline)
  93. return Array.Empty<int>();
  94. _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
  95. return _newlineToken;
  96. }
  97. /// <inheritdoc />
  98. protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
  99. {
  100. foreach (var (key, value) in LogitBias)
  101. logits[key] += value;
  102. }
  103. /// <inheritdoc />
  104. protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
  105. {
  106. // Apply penalties to candidates
  107. candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
  108. // Restore protected tokens, so they are not affected by repetition penalties
  109. RestoreProtectedTokens(candidates);
  110. // Apply the normal llama.cpp pipeline
  111. candidates.ApplyGrammar(ctx, Grammar);
  112. candidates.TopK(ctx, TopK);
  113. candidates.TailFree(ctx, TailFreeZ);
  114. candidates.LocallyTypical(ctx, TypicalP);
  115. candidates.TopP(ctx, TopP);
  116. candidates.MinP(ctx, MinP);
  117. candidates.Temperature(ctx, Temperature);
  118. var id = candidates.SampleToken(ctx);
  119. Grammar?.AcceptToken(ctx, id);
  120. return id;
  121. }
  122. /// <inheritdoc />
  123. protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
  124. {
  125. return candidates.SampleToken(ctx);
  126. }
  127. }