Browse Source

made casts to/from int explicit, fixed places affected

tags/v0.10.0
Martin Evans 1 year ago
parent
commit
2eb52b1630
9 changed files with 20 additions and 20 deletions
  1. +2
    -2
      LLama/LLamaContext.cs
  2. +1
    -1
      LLama/LLamaStatelessExecutor.cs
  3. +1
    -1
      LLama/Native/LLamaBatchSafeHandle.cs
  4. +2
    -2
      LLama/Native/LLamaToken.cs
  5. +2
    -2
      LLama/Native/LLamaTokenData.cs
  6. +1
    -1
      LLama/Native/LLamaTokenDataArray.cs
  7. +1
    -1
      LLama/Native/SafeLLamaContextHandle.cs
  8. +6
    -6
      LLama/Sampling/BaseSamplingPipeline.cs
  9. +4
    -4
      LLama/StreamingTokenDecoder.cs

+ 2
- 2
LLama/LLamaContext.cs View File

@@ -309,12 +309,12 @@ namespace LLama
if (logitBias is not null)
{
foreach (var (key, value) in logitBias)
logits[key.Value] += value;
logits[(int)key] += value;
}

// Save the newline logit value
var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_logit = logits[nl_token.Value];
var nl_logit = logits[(int)nl_token];

// Convert logits into token candidates
var candidates_p = LLamaTokenDataArray.Create(logits);


+ 1
- 1
LLama/LLamaStatelessExecutor.cs View File

@@ -71,7 +71,7 @@ namespace LLama
var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
var lastTokens = new List<LLamaToken>(repeat_last_n);
for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add(0);
lastTokens.Add((LLamaToken)0);

// Tokenize the prompt
var tokens = Context.Tokenize(prompt).ToList();


+ 1
- 1
LLama/Native/LLamaBatchSafeHandle.cs View File

@@ -135,7 +135,7 @@ public sealed class LLamaBatchSafeHandle
{
unsafe
{
NativeBatch.token[NativeBatch.n_tokens] = token.Value;
NativeBatch.token[NativeBatch.n_tokens] = token;
NativeBatch.pos[NativeBatch.n_tokens] = pos;
NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length;



+ 2
- 2
LLama/Native/LLamaToken.cs View File

@@ -11,7 +11,7 @@ public readonly record struct LLamaToken
/// <summary>
/// The raw value
/// </summary>
public readonly int Value;
private readonly int Value;

/// <summary>
/// Create a new LLamaToken
@@ -34,5 +34,5 @@ public readonly record struct LLamaToken
/// </summary>
/// <param name="value"></param>
/// <returns></returns>
public static implicit operator LLamaToken(int value) => new(value);
public static explicit operator LLamaToken(int value) => new(value);
}

+ 2
- 2
LLama/Native/LLamaTokenData.cs View File

@@ -11,7 +11,7 @@ public struct LLamaTokenData
/// <summary>
/// token id
/// </summary>
public int id;
public LLamaToken id;

/// <summary>
/// log-odds of the token
@@ -29,7 +29,7 @@ public struct LLamaTokenData
/// <param name="id"></param>
/// <param name="logit"></param>
/// <param name="p"></param>
public LLamaTokenData(int id, float logit, float p)
public LLamaTokenData(LLamaToken id, float logit, float p)
{
this.id = id;
this.logit = logit;


+ 1
- 1
LLama/Native/LLamaTokenDataArray.cs View File

@@ -39,7 +39,7 @@ namespace LLama.Native
{
var candidates = new LLamaTokenData[logits.Length];
for (var token_id = 0; token_id < logits.Length; token_id++)
candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);
candidates[token_id] = new LLamaTokenData((LLamaToken)token_id, logits[token_id], 0.0f);

return new LLamaTokenDataArray(candidates);
}


+ 1
- 1
LLama/Native/SafeLLamaContextHandle.cs View File

@@ -178,7 +178,7 @@ namespace LLama.Native
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(int token, Span<byte> dest)
public int TokenToSpan(LLamaToken token, Span<byte> dest)
{
return ThrowIfDisposed().TokenToSpan(token, dest);
}


+ 6
- 6
LLama/Sampling/BaseSamplingPipeline.cs View File

@@ -12,22 +12,22 @@ public abstract class BaseSamplingPipeline
: ISamplingPipeline
{
private int _savedLogitsCount;
private (int index, float logit)[]? _savedLogits;
private (LLamaToken index, float logit)[]? _savedLogits;

/// <inheritdoc/>
public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
{
var protectedLogits = GetProtectedTokens(ctx);
_savedLogitsCount = protectedLogits.Count;
_savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
_savedLogits = ArrayPool<(LLamaToken, float)>.Shared.Rent(_savedLogitsCount);
try
{
// Save the values of protected logits
for (var i = 0; i < protectedLogits.Count; i++)
{
var index = protectedLogits[i];
var value = logits[index.Value];
_savedLogits[i] = (index.Value, value);
var value = logits[(int)index];
_savedLogits[i] = (index, value);
}

// Process raw logits
@@ -47,7 +47,7 @@ public abstract class BaseSamplingPipeline
}
finally
{
ArrayPool<(int, float)>.Shared.Return(_savedLogits);
ArrayPool<(LLamaToken, float)>.Shared.Return(_savedLogits);
_savedLogits = null;
_savedLogitsCount = 0;
}
@@ -74,7 +74,7 @@ public abstract class BaseSamplingPipeline

// Restore the values of protected logits
for (var i = 0; i < saved.Length; i++)
logits[saved[i].index] = saved[i].logit;
logits[(int)saved[i].index] = saved[i].logit;
}

/// <summary>


+ 4
- 4
LLama/StreamingTokenDecoder.cs View File

@@ -69,7 +69,7 @@ namespace LLama
/// Add a single token to the decoder
/// </summary>
/// <param name="token"></param>
public void Add(int token)
public void Add(LLamaToken token)
{
var charsArr = ArrayPool<char>.Shared.Rent(16);
var bytesArr = ArrayPool<byte>.Shared.Rent(16);
@@ -108,7 +108,7 @@ namespace LLama

// Converts a single token into bytes, using the `bytes` array as temporary storage.
// If the `bytes` array is too small it will get a larger one from the ArrayPool.
static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
static Span<byte> TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaModelHandle model)
{
// Try to get bytes
var l = model.TokenToSpan(token, bytes);
@@ -133,9 +133,9 @@ namespace LLama
/// Add a single token to the decoder
/// </summary>
/// <param name="token"></param>
public void Add(LLamaToken token)
public void Add(int token)
{
Add((int)token);
Add((LLamaToken)token);
}

/// <summary>


Loading…
Cancel
Save