|
|
|
@@ -12,22 +12,22 @@ public abstract class BaseSamplingPipeline |
|
|
|
: ISamplingPipeline |
|
|
|
{ |
|
|
|
private int _savedLogitsCount; |
|
|
|
private (int index, float logit)[]? _savedLogits; |
|
|
|
private (LLamaToken index, float logit)[]? _savedLogits; |
|
|
|
|
|
|
|
/// <inheritdoc/> |
|
|
|
public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens) |
|
|
|
{ |
|
|
|
var protectedLogits = GetProtectedTokens(ctx); |
|
|
|
_savedLogitsCount = protectedLogits.Count; |
|
|
|
_savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount); |
|
|
|
_savedLogits = ArrayPool<(LLamaToken, float)>.Shared.Rent(_savedLogitsCount); |
|
|
|
try |
|
|
|
{ |
|
|
|
// Save the values of protected logits |
|
|
|
for (var i = 0; i < protectedLogits.Count; i++) |
|
|
|
{ |
|
|
|
var index = protectedLogits[i]; |
|
|
|
var value = logits[index.Value]; |
|
|
|
_savedLogits[i] = (index.Value, value); |
|
|
|
var value = logits[(int)index]; |
|
|
|
_savedLogits[i] = (index, value); |
|
|
|
} |
|
|
|
|
|
|
|
// Process raw logits |
|
|
|
@@ -47,7 +47,7 @@ public abstract class BaseSamplingPipeline |
|
|
|
} |
|
|
|
finally |
|
|
|
{ |
|
|
|
ArrayPool<(int, float)>.Shared.Return(_savedLogits); |
|
|
|
ArrayPool<(LLamaToken, float)>.Shared.Return(_savedLogits); |
|
|
|
_savedLogits = null; |
|
|
|
_savedLogitsCount = 0; |
|
|
|
} |
|
|
|
@@ -74,7 +74,7 @@ public abstract class BaseSamplingPipeline |
|
|
|
|
|
|
|
// Restore the values of protected logits |
|
|
|
for (var i = 0; i < saved.Length; i++) |
|
|
|
logits[saved[i].index] = saved[i].logit; |
|
|
|
logits[(int)saved[i].index] = saved[i].logit; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
|