using System;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;

namespace LLama.Native
{
    /// <summary>
    /// A safe wrapper around a llama_context
    /// </summary>
    // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API)
    public sealed class SafeLLamaContextHandle
        : SafeLLamaHandleBase
    {
        #region properties and fields
        /// <summary>
        /// Total number of tokens in vocabulary of this model
        /// </summary>
        public int VocabCount => ThrowIfDisposed().VocabCount;

        /// <summary>
        /// Total number of tokens in the context
        /// </summary>
        public int ContextSize => NativeApi.llama_n_ctx(this);

        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize;

        /// <summary>
        /// Get the model which this context is using
        /// </summary>
        public SafeLlamaModelHandle ModelHandle => ThrowIfDisposed();

        // Holds a DangerousAddRef'd reference to the model so it cannot be
        // freed while this context is still alive. Released in ReleaseHandle.
        private SafeLlamaModelHandle? _model;
        #endregion

        #region construction/destruction
        /// <inheritdoc />
        protected override bool ReleaseHandle()
        {
            llama_free(handle);
            SetHandle(IntPtr.Zero);

            // Decrement refcount on model
            _model?.DangerousRelease();
            _model = null;

            return true;
        }

        /// <summary>
        /// Throw if this handle, or the model handle it depends on, has been disposed.
        /// </summary>
        /// <returns>The (still live) model handle</returns>
        /// <exception cref="ObjectDisposedException"></exception>
        private SafeLlamaModelHandle ThrowIfDisposed()
        {
            if (IsClosed)
                throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed");
            if (_model == null || _model.IsClosed)
                throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed");

            return _model!;
        }

        /// <summary>
        /// Create a new llama_state for the given model
        /// </summary>
        /// <param name="model">Model to create a context for</param>
        /// <param name="lparams">Parameters for the new context</param>
        /// <returns>A handle to the newly created context</returns>
        /// <exception cref="RuntimeError">Thrown if the native call fails to create a context</exception>
        public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
        {
            var ctx = llama_new_context_with_model(model, lparams);

            // A SafeHandle-derived return value from P/Invoke is never null; the
            // marshaller always constructs an instance. Failure is signalled by an
            // invalid (zero) underlying handle, so check IsInvalid instead.
            if (ctx.IsInvalid)
                throw new RuntimeError("Failed to create context from model");

            // Increment the model reference count while this context exists.
            // DangerousAddRef throws if it fails, so there is no need to check "success"
            ctx._model = model;
            var success = false;
            ctx._model.DangerousAddRef(ref success);

            return ctx;
        }
        #endregion

        #region Native API
        static SafeLLamaContextHandle()
        {
            // This ensures that `NativeApi` has been loaded before calling the two native methods below
            NativeApi.llama_empty_call();
        }

        /// <summary>
        /// Create a new llama_context with the given model. **This should never be called directly! Always use SafeLLamaContextHandle.Create**!
        /// </summary>
        /// <param name="model"></param>
        /// <param name="params"></param>
        /// <returns></returns>
        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
        private static extern SafeLLamaContextHandle llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);

        /// <summary>
        /// Frees all allocated memory in the given llama_context
        /// </summary>
        /// <param name="ctx"></param>
        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
        private static extern void llama_free(IntPtr ctx);
        #endregion

        /// <summary>
        /// Token logits obtained from the last call to llama_eval().
        /// The logits for the last token are stored in the last row.
        /// Can be mutated in order to change the probabilities of the next token.<br />
        /// Rows: n_tokens<br />
        /// Cols: n_vocab
        /// </summary>
        /// <returns></returns>
        public Span<float> GetLogits()
        {
            var model = ThrowIfDisposed();

            unsafe
            {
                var logits = NativeApi.llama_get_logits(this);
                return new Span<float>(logits, model.VocabCount);
            }
        }

        /// <summary>
        /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
        /// </summary>
        /// <param name="i"></param>
        /// <returns></returns>
        public Span<float> GetLogitsIth(int i)
        {
            var model = ThrowIfDisposed();

            unsafe
            {
                var logits = NativeApi.llama_get_logits_ith(this, i);
                return new Span<float>(logits, model.VocabCount);
            }
        }

        #region tokens
        /// <summary>
        /// Convert the given text into tokens
        /// </summary>
        /// <param name="text">The text to tokenize</param>
        /// <param name="add_bos">Whether the "BOS" token should be added</param>
        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <param name="encoding">Encoding to use for the text</param>
        /// <returns></returns>
        public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            return ThrowIfDisposed().Tokenize(text, add_bos, special, encoding);
        }

        /// <summary>
        /// Convert a single llama token into bytes
        /// </summary>
        /// <param name="token">Token to decode</param>
        /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
        /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
        public uint TokenToSpan(LLamaToken token, Span<byte> dest)
        {
            return ThrowIfDisposed().TokenToSpan(token, dest);
        }
        #endregion

        /// <summary>
        /// Run the llama inference to obtain the logits and probabilities for the next token.
        /// </summary>
        /// <param name="tokens">The provided batch of new tokens to process</param>
        /// <param name="n_past">the number of tokens to use from previous eval calls</param>
        /// <returns>Returns true on success</returns>
        [Obsolete("use llama_decode() instead")]
        public bool Eval(ReadOnlySpan<LLamaToken> tokens, int n_past)
        {
            unsafe
            {
                fixed (LLamaToken* pinned = tokens)
                {
                    // the entire `eval` system needs replacing with the new batch system!
                    var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past);
                    return ret == 0;
                }
            }
        }

        /// <summary>
        /// Process a batch of tokens with this context
        /// </summary>
        /// <param name="batch"></param>
        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
        ///  - 0: success<br />
        ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
        ///  - &lt; 0: error<br />
        /// </returns>
        public int Decode(LLamaBatch batch)
        {
            using (batch.ToNativeBatch(out var nb))
                return NativeApi.llama_decode(this, nb);
        }

        #region state
        /// <summary>
        /// Get the size of the state, when saved as bytes
        /// </summary>
        public ulong GetStateSize()
        {
            return NativeApi.llama_get_state_size(this);
        }

        /// <summary>
        /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
        /// </summary>
        /// <param name="dest">Destination to write to</param>
        /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
        /// <returns>The number of bytes written to dest</returns>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
        public unsafe ulong GetState(byte* dest, ulong size)
        {
            return GetState(new IntPtr(dest), size);
        }

        /// <summary>
        /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
        /// </summary>
        /// <param name="dest">Destination to write to</param>
        /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
        /// <returns>The number of bytes written to dest</returns>
        /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
        public ulong GetState(IntPtr dest, ulong size)
        {
            var required = GetStateSize();
            if (size < required)
                throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");

            unsafe
            {
                return NativeApi.llama_copy_state_data(this, (byte*)dest.ToPointer());
            }
        }

        /// <summary>
        /// Set the raw state of this context
        /// </summary>
        /// <param name="src">The pointer to read the state from</param>
        /// <returns>Number of bytes read from the src pointer</returns>
        public unsafe ulong SetState(byte* src)
        {
            return SetState(new IntPtr(src));
        }

        /// <summary>
        /// Set the raw state of this context
        /// </summary>
        /// <param name="src">The pointer to read the state from</param>
        /// <returns>Number of bytes read from the src pointer</returns>
        public ulong SetState(IntPtr src)
        {
            unsafe
            {
                return NativeApi.llama_set_state_data(this, (byte*)src.ToPointer());
            }
        }
        #endregion

        /// <summary>
        /// Set the RNG seed
        /// </summary>
        /// <param name="seed"></param>
        public void SetSeed(uint seed)
        {
            NativeApi.llama_set_rng_seed(this, seed);
        }
    }
}