|
|
|
@@ -11,11 +11,29 @@ namespace LLama.Native |
|
|
|
public class SafeLLamaContextHandle |
|
|
|
: SafeLLamaHandleBase |
|
|
|
{ |
|
|
|
#region properties and fields |
|
|
|
/// <summary> |
|
|
|
/// Total number of tokens in vocabulary of this model |
|
|
|
/// </summary> |
|
|
|
public int VocabCount => ThrowIfDisposed().VocabCount; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Total number of tokens in the context |
|
|
|
/// </summary> |
|
|
|
public int ContextSize => ThrowIfDisposed().ContextSize; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Dimension of embedding vectors |
|
|
|
/// </summary> |
|
|
|
public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// This field guarantees that a reference to the model is held for as long as this handle is held |
|
|
|
/// </summary> |
|
|
|
private SafeLlamaModelHandle? _model; |
|
|
|
#endregion |
|
|
|
|
|
|
|
#region construction/destruction |
|
|
|
/// <summary> |
|
|
|
/// Create a new SafeLLamaContextHandle |
|
|
|
/// </summary> |
|
|
|
@@ -44,6 +62,16 @@ namespace LLama.Native |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
private SafeLlamaModelHandle ThrowIfDisposed() |
|
|
|
{ |
|
|
|
if (IsClosed) |
|
|
|
throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed"); |
|
|
|
if (_model == null || _model.IsClosed) |
|
|
|
throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed"); |
|
|
|
|
|
|
|
return _model; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Create a new llama_state for the given model |
|
|
|
/// </summary> |
|
|
|
@@ -59,6 +87,7 @@ namespace LLama.Native |
|
|
|
|
|
|
|
return new(ctx_ptr, model); |
|
|
|
} |
|
|
|
#endregion |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Convert the given text into tokens |
|
|
|
@@ -70,6 +99,8 @@ namespace LLama.Native |
|
|
|
/// <exception cref="RuntimeError"></exception> |
|
|
|
public int[] Tokenize(string text, bool add_bos, Encoding encoding) |
|
|
|
{ |
|
|
|
ThrowIfDisposed(); |
|
|
|
|
|
|
|
// Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't |
|
|
|
// possibly be more than this. |
|
|
|
var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0); |
|
|
|
@@ -97,5 +128,25 @@ namespace LLama.Native |
|
|
|
ArrayPool<int>.Shared.Return(temporaryArray); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Token logits obtained from the last call to llama_eval() |
|
|
|
/// The logits for the last token are stored in the last row |
|
|
|
/// Can be mutated in order to change the probabilities of the next token.<br /> |
|
|
|
/// Rows: n_tokens<br /> |
|
|
|
/// Cols: n_vocab |
|
|
|
/// </summary> |
|
|
|
/// <param name="ctx"></param> |
|
|
|
/// <returns></returns> |
|
|
|
public Span<float> GetLogits() |
|
|
|
{ |
|
|
|
var model = ThrowIfDisposed(); |
|
|
|
|
|
|
|
unsafe |
|
|
|
{ |
|
|
|
var logits = NativeApi.llama_get_logits(this); |
|
|
|
return new Span<float>(logits, model.VocabCount); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |