
- Moved `GetLogits` into `SafeLLamaContextHandle`

- Added a disposal check to `SafeLLamaContextHandle`
tag: v0.5.1
Martin Evans, 2 years ago
commit 2d811b2603
5 changed files with 64 additions and 10 deletions
  1. LLama/LLamaModel.cs (+2 -2)
  2. LLama/Native/NativeApi.cs (+2 -2)
  3. LLama/Native/SafeLLamaContextHandle.cs (+51 -0)
  4. LLama/Native/SafeLlamaModelHandle.cs (+3 -3)
  5. LLama/Utils.cs (+6 -3)

LLama/LLamaModel.cs (+2 -2)

@@ -284,8 +284,8 @@ namespace LLama
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
- var n_vocab = NativeApi.llama_n_vocab(_ctx);
- var logits = Utils.GetLogits(_ctx, n_vocab);
+ var n_vocab = _ctx.VocabCount;
+ var logits = _ctx.GetLogits();

// Apply params.logit_bias map
if(logitBias is not null)
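
The span returned by `_ctx.GetLogits()` wraps the native logits buffer directly, so writing into it is how the logit-bias map takes effect before sampling. A minimal sketch of that idea (the loop body is illustrative rather than the library's verbatim code, and `logitBias` is assumed to be a token-id-to-float dictionary):

// Illustrative sketch: mutate the logits span in place to bias specific tokens.
// Assumes logitBias is a Dictionary<int, float> keyed by token id.
var logits = _ctx.GetLogits();       // Span<float> over the native llama_get_logits buffer
if (logitBias is not null)
{
    foreach (var (tokenId, bias) in logitBias)
        logits[tokenId] += bias;     // shifts the probability of that token at the next step
}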


LLama/Native/NativeApi.cs (+2 -2)

@@ -257,8 +257,8 @@ namespace LLama.Native
/// <summary>
/// Token logits obtained from the last call to llama_eval()
/// The logits for the last token are stored in the last row
- /// Can be mutated in order to change the probabilities of the next token
- /// Rows: n_tokens
+ /// Can be mutated in order to change the probabilities of the next token.<br />
+ /// Rows: n_tokens<br />
/// Cols: n_vocab
/// </summary>
/// <param name="ctx"></param>


LLama/Native/SafeLLamaContextHandle.cs (+51 -0)

@@ -11,11 +11,29 @@ namespace LLama.Native
public class SafeLLamaContextHandle
: SafeLLamaHandleBase
{
#region properties and fields
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
public int VocabCount => ThrowIfDisposed().VocabCount;

/// <summary>
/// Total number of tokens in the context
/// </summary>
public int ContextSize => ThrowIfDisposed().ContextSize;

/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount;

/// <summary>
/// This field guarantees that a reference to the model is held for as long as this handle is held
/// </summary>
private SafeLlamaModelHandle? _model;
#endregion

#region construction/destruction
/// <summary>
/// Create a new SafeLLamaContextHandle
/// </summary>
@@ -44,6 +62,16 @@ namespace LLama.Native
return true;
}

private SafeLlamaModelHandle ThrowIfDisposed()
{
if (IsClosed)
throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed");
if (_model == null || _model.IsClosed)
throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed");

return _model;
}

/// <summary>
/// Create a new llama_state for the given model
/// </summary>
@@ -59,6 +87,7 @@ namespace LLama.Native

return new(ctx_ptr, model);
}
#endregion

/// <summary>
/// Convert the given text into tokens
@@ -70,6 +99,8 @@ namespace LLama.Native
/// <exception cref="RuntimeError"></exception>
public int[] Tokenize(string text, bool add_bos, Encoding encoding)
{
ThrowIfDisposed();

// Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
// possibly be more than this.
var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);
@@ -97,5 +128,25 @@ namespace LLama.Native
ArrayPool<int>.Shared.Return(temporaryArray);
}
}

/// <summary>
/// Token logits obtained from the last call to llama_eval()
/// The logits for the last token are stored in the last row
/// Can be mutated in order to change the probabilities of the next token.<br />
/// Rows: n_tokens<br />
/// Cols: n_vocab
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
public Span<float> GetLogits()
{
var model = ThrowIfDisposed();

unsafe
{
var logits = NativeApi.llama_get_logits(this);
return new Span<float>(logits, model.VocabCount);
}
}
}
}
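
With these additions the context handle exposes tokenization and logits access directly, and the entry points run through `ThrowIfDisposed`, so using a freed context (or a context whose model handle has been freed) fails with an `ObjectDisposedException` instead of dereferencing a dangling native pointer. A rough usage sketch, assuming `ctx` is a live `SafeLLamaContextHandle`:

// Hypothetical usage; assumes `ctx` is a live SafeLLamaContextHandle.
var tokens = ctx.Tokenize("Hello world", add_bos: true, Encoding.UTF8);
// ... feed `tokens` through an eval call (e.g. Utils.Eval) before reading logits ...

Span<float> logits = ctx.GetLogits();   // length == ctx.VocabCount
logits[tokens[^1]] += 1.0f;             // the span is mutable native memory

ctx.Dispose();
// ctx.GetLogits();                     // would now throw ObjectDisposedException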

LLama/Native/SafeLlamaModelHandle.cs (+3 -3)

@@ -13,17 +13,17 @@ namespace LLama.Native
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
- public int VocabCount { get; set; }
+ public int VocabCount { get; }

/// <summary>
/// Total number of tokens in the context
/// </summary>
- public int ContextSize { get; set; }
+ public int ContextSize { get; }

/// <summary>
/// Dimension of embedding vectors
/// </summary>
- public int EmbeddingCount { get; set; }
+ public int EmbeddingCount { get; }

internal SafeLlamaModelHandle(IntPtr handle)
: base(handle)


LLama/Utils.cs (+6 -3)

@@ -33,10 +33,13 @@ namespace LLama
return ctx.Tokenize(text, add_bos, encoding);
}

- public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
+ [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")]
+ public static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
{
- var logits = NativeApi.llama_get_logits(ctx);
- return new Span<float>(logits, length);
+ if (length != ctx.VocabCount)
+     throw new ArgumentException("length must be the VocabSize");
+
+ return ctx.GetLogits();
}

public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
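
For callers still using the static helper, the obsolete shim now only validates `length` against the vocabulary size and forwards to the handle, so migration is a one-line change. A hedged before/after sketch (variable names are illustrative):

// Before (helper is now marked [Obsolete]): the caller supplied the vocabulary size.
// var logits = Utils.GetLogits(ctx, NativeApi.llama_n_vocab(ctx));

// After: the handle knows its own vocabulary size.
var logits = ctx.GetLogits();   // logits.Length == ctx.VocabCount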

