using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.IO;
using System.IO.MemoryMappedFiles;
using LLama.Common;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using LLama.Extensions;
using LLama.Abstractions;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
using System.Threading;

namespace LLama
{
    /// <summary>
    /// A llama_context, which holds all the context required to interact with a model
    /// </summary>
    public sealed class LLamaContext
        : IDisposable
    {
        private readonly ILogger? _logger;

        /// <summary>
        /// Total number of tokens in vocabulary of this model
        /// </summary>
        public int VocabCount => NativeHandle.VocabCount;

        /// <summary>
        /// Total number of tokens in the context
        /// </summary>
        public uint ContextSize => NativeHandle.ContextSize;

        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        public int EmbeddingSize => NativeHandle.EmbeddingSize;

        /// <summary>
        /// The context params set for this context
        /// </summary>
        public IContextParams Params { get; }

        /// <summary>
        /// The native handle, which is used to be passed to the native APIs
        /// </summary>
        /// <remarks>Be careful how you use this!</remarks>
        public SafeLLamaContextHandle NativeHandle { get; }

        /// <summary>
        /// The encoding set for this model to deal with text input.
        /// </summary>
        public Encoding Encoding { get; }

        private uint _generationThreads;
        private uint _batchThreads;

        /// <summary>
        /// Get or set the number of threads to use for generation
        /// </summary>
        public uint GenerationThreads
        {
            get => _generationThreads;
            set
            {
                _generationThreads = value;

                // Both counts are always pushed together; the other one comes from the cached copy.
                NativeHandle.SetThreads(_generationThreads, _batchThreads);
            }
        }

        /// <summary>
        /// Get or set the number of threads to use for batch processing
        /// </summary>
        public uint BatchThreads
        {
            get => _batchThreads;
            set
            {
                _batchThreads = value;

                // Both counts are always pushed together; the other one comes from the cached copy.
                NativeHandle.SetThreads(_generationThreads, _batchThreads);
            }
        }

        /// <summary>
        /// Get the maximum batch size for this context
        /// </summary>
        public uint BatchSize => NativeHandle.BatchSize;

        /// <summary>
        /// Create a new LLamaContext for the given LLamaWeights
        /// </summary>
        /// <param name="model">The weights to create the context for</param>
        /// <param name="params">Parameters for the context; also supplies <see cref="Encoding"/></param>
        /// <param name="logger">Optional logger, stored on the context</param>
        /// <exception cref="ObjectDisposedException">Thrown if the native handle of <paramref name="model"/> is already closed</exception>
        public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger = null)
        {
            if (model.NativeHandle.IsClosed)
            {
                // The single-string constructor treats its argument as the *object name*, so use the
                // (objectName, message) overload to make the text the actual exception message.
                throw new ObjectDisposedException(nameof(model), "Cannot create context, model weights have been disposed");
            }

            Params = @params;
            _logger = logger;
            Encoding = @params.Encoding;

            @params.ToLlamaContextParams(out var lparams);
            NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);

            // It's not possible to get these values from llama.cpp, store a copy of them here.
            _generationThreads = lparams.n_threads;
            _batchThreads = lparams.n_threads_batch;
        }

        /// <summary>
        /// Set the seed for the RNG
        /// </summary>
        /// <param name="seed">Seed value forwarded to the native handle</param>
        public void SetSeed(uint seed)
        {
            NativeHandle.SetSeed(seed);
        }

        /// <summary>
        /// Tokenize a string.
        /// </summary>
        /// <param name="text">The text to tokenize</param>
        /// <param name="addBos">Whether to add a bos to the text.</param>
        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <returns>The tokens produced by the native handle, using <see cref="Encoding"/></returns>
        public LLamaToken[] Tokenize(string text, bool addBos = true, bool special = false)
        {
            return NativeHandle.Tokenize(text, addBos, special, Encoding);
        }

        /// <summary>
        /// Detokenize the tokens to text.
        /// </summary>
        /// <param name="tokens">The tokens to convert back into text</param>
        /// <returns>The text read from a decoder that was fed all of <paramref name="tokens"/></returns>
        [Obsolete("Use a `StreamingTokenDecoder` instead")]
        public string DeTokenize(IReadOnlyList<LLamaToken> tokens)
        {
            // Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder!
            // It should be kept around for the entire time you are decoding one stream of tokens.
            var decoder = new StreamingTokenDecoder(this);
            decoder.AddRange(tokens);
            return decoder.Read();
        }

        #region state load/save
        /// <summary>
        /// Save the state to specified path.
        /// </summary>
        /// <param name="filename">Path of the file to write; an existing file is deleted first</param>
        public void SaveState(string filename)
        {
            // Delete that file before overwriting it
            if (File.Exists(filename))
                File.Delete(filename);

            // Estimate size of state to write to disk, this is always equal to or greater than the actual size
            var estimatedStateSize = checked((long)NativeHandle.GetStateSize());

            // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
            long writtenBytes;
            using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, estimatedStateSize))
            using (var view = file.CreateViewAccessor(0, estimatedStateSize))
            {
                unsafe
                {
                    byte* ptr = null;
                    view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
                    try
                    {
                        writtenBytes = (long)NativeHandle.GetState(ptr, (ulong)estimatedStateSize);
                    }
                    finally
                    {
                        view.SafeMemoryMappedViewHandle.ReleasePointer();
                    }
                }
            }

            // Truncate the file to the actual size of data that was written.
            // This happens only after the mapping above has been disposed.
            using (var fileStream = new FileStream(filename, FileMode.Open))
                fileStream.SetLength(writtenBytes);
        }

        /// <summary>
        /// Save the state of a particular sequence to specified path.
        /// </summary>
        /// <param name="filename">Path of the file to write; an existing file is deleted first</param>
        /// <param name="sequence">The sequence whose state is saved</param>
        public void SaveState(string filename, LLamaSeqId sequence)
        {
            // Delete that file before overwriting it
            if (File.Exists(filename))
                File.Delete(filename);

            // Estimate size of state to write to disk, this is always equal to or greater than the actual size
            var estimatedStateSize = checked((long)NativeHandle.GetStateSize(sequence));

            // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
            long writtenBytes;
            using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, estimatedStateSize))
            using (var view = file.CreateViewAccessor(0, estimatedStateSize))
            {
                unsafe
                {
                    byte* ptr = null;
                    view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
                    try
                    {
                        writtenBytes = (long)NativeHandle.GetState(ptr, (ulong)estimatedStateSize, sequence);
                    }
                    finally
                    {
                        view.SafeMemoryMappedViewHandle.ReleasePointer();
                    }
                }
            }

            // Truncate the file to the actual size of data that was written.
            // This happens only after the mapping above has been disposed.
            using (var fileStream = new FileStream(filename, FileMode.Open))
                fileStream.SetLength(writtenBytes);
        }

        /// <summary>
        /// Get the state data as an opaque handle, which can be loaded later using <see cref="LoadState(State)"/>
        /// </summary>
        /// <remarks>Use <see cref="SaveState(string)"/> if you intend to save this state to disk.</remarks>
        /// <returns>A handle owning an unmanaged copy of the state</returns>
        public State GetState()
        {
            var stateSize = NativeHandle.GetStateSize();

            // Allocate a chunk of memory large enough to hold the entire state.
            // The cast is checked (as in SaveState) so an oversized value fails loudly instead of wrapping.
            var memory = Marshal.AllocHGlobal(checked((nint)stateSize));
            try
            {
                // Copy the state data into memory, discover the actual size required
                ulong actualSize;
                unsafe
                {
                    actualSize = NativeHandle.GetState((byte*)memory, stateSize);
                }

                // Shrink to size
                memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);

                // Wrap memory in a "state"; from here the State releases it in ReleaseHandle
                var state = new State(memory, actualSize);

                // Set memory to zero, to prevent it being freed in finally block
                memory = IntPtr.Zero;

                return state;
            }
            finally
            {
                if (memory != IntPtr.Zero)
                    Marshal.FreeHGlobal(memory);
            }
        }

        /// <summary>
        /// Get the state data of one sequence as an opaque handle, which can be loaded later using <see cref="LoadState(SequenceState, LLamaSeqId)"/>
        /// </summary>
        /// <remarks>Use <see cref="SaveState(string, LLamaSeqId)"/> if you intend to save this state to disk.</remarks>
        /// <param name="sequence">The sequence whose state is captured</param>
        /// <returns>A handle owning an unmanaged copy of the sequence state</returns>
        public SequenceState GetState(LLamaSeqId sequence)
        {
            var stateSize = NativeHandle.GetStateSize(sequence);

            // Allocate a chunk of memory large enough to hold the entire state.
            // The cast is checked (as in SaveState) so an oversized value fails loudly instead of wrapping.
            var memory = Marshal.AllocHGlobal(checked((nint)stateSize));
            try
            {
                // Copy the state data into memory, discover the actual size required
                ulong actualSize;
                unsafe
                {
                    actualSize = NativeHandle.GetState((byte*)memory, stateSize, sequence);
                }

                // Shrink to size
                memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);

                // Wrap memory in a "state"; from here the SequenceState releases it in ReleaseHandle
                var state = new SequenceState(memory, actualSize);

                // Set memory to zero, to prevent it being freed in finally block
                memory = IntPtr.Zero;

                return state;
            }
            finally
            {
                if (memory != IntPtr.Zero)
                    Marshal.FreeHGlobal(memory);
            }
        }

        /// <summary>
        /// Load the state from specified path.
        /// </summary>
        /// <param name="filename">Path of an existing state file</param>
        public void LoadState(string filename)
        {
            // Map the state file into memory and hand that pointer directly to the native handle to load from
            using var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null);
            using var view = file.CreateViewAccessor();

            unsafe
            {
                byte* ptr = null;
                view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
                try
                {
                    NativeHandle.SetState(ptr);
                }
                finally
                {
                    view.SafeMemoryMappedViewHandle.ReleasePointer();
                }
            }
        }

        /// <summary>
        /// Load the state from specified path into a particular sequence
        /// </summary>
        /// <param name="filename">Path of an existing state file</param>
        /// <param name="sequence">The sequence to load the state into</param>
        public void LoadState(string filename, LLamaSeqId sequence)
        {
            // Map the state file into memory and hand that pointer directly to the native handle to load from
            using var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null);
            using var view = file.CreateViewAccessor();

            unsafe
            {
                byte* ptr = null;
                view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
                try
                {
                    NativeHandle.SetState(ptr, sequence);
                }
                finally
                {
                    view.SafeMemoryMappedViewHandle.ReleasePointer();
                }
            }
        }

        /// <summary>
        /// Load the state from memory.
        /// </summary>
        /// <param name="state">A state previously obtained from <see cref="GetState()"/></param>
        public void LoadState(State state)
        {
            unsafe
            {
                NativeHandle.SetState((byte*)state.DangerousGetHandle());
            }
        }

        /// <summary>
        /// Load the state from memory into a particular sequence
        /// </summary>
        /// <param name="state">A state previously obtained from <see cref="GetState(LLamaSeqId)"/></param>
        /// <param name="sequence">The sequence to load the state into</param>
        public void LoadState(SequenceState state, LLamaSeqId sequence)
        {
            unsafe
            {
                NativeHandle.SetState((byte*)state.DangerousGetHandle(), sequence);
            }
        }
        #endregion

        /// <summary>
        /// Sample a single token from this context, using the given sampling pipeline
        /// </summary>
        /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
        /// <param name="lastTokens">The tokens recently returned from the model</param>
        /// <returns>The selected token</returns>
        public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan<LLamaToken> lastTokens)
        {
            var token = pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);

            // Report the chosen token back to the pipeline before returning it
            pipeline.Accept(NativeHandle, token);

            return token;
        }

        /// <summary>
        /// Perform the sampling. Please don't use it unless you fully know what it does.
        /// </summary>
        /// <param name="candidates">The candidate tokens to sample from</param>
        /// <param name="mirostat_mu">Mirostat mu; when null, <c>2 * mirostatTau</c> is used. Written back unless greedy sampling was used.</param>
        /// <param name="temperature">Sampling temperature; a value of zero or below selects greedy sampling</param>
        /// <param name="mirostat">Which mirostat variant (if any) to sample with</param>
        /// <param name="mirostatTau">Tau value passed to the mirostat samplers</param>
        /// <param name="mirostatEta">Eta value passed to the mirostat samplers</param>
        /// <param name="topK">Value passed to the top-K filter (non-mirostat path only)</param>
        /// <param name="topP">Value passed to the top-P filter (non-mirostat path only)</param>
        /// <param name="tfsZ">Value passed to the tail-free filter (non-mirostat path only)</param>
        /// <param name="typicalP">Value passed to the locally-typical filter (non-mirostat path only)</param>
        /// <param name="grammar">Optional grammar, applied to the candidates before sampling and told about the selected token afterwards</param>
        /// <param name="minP">Value passed to the min-P filter (non-mirostat path only)</param>
        /// <returns>The selected token</returns>
        public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
                                 float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
                                 SafeLLamaGrammarHandle? grammar, float minP)
        {
            LLamaToken id;

            if (grammar != null)
            {
                candidates.ApplyGrammar(NativeHandle, grammar);
            }

            if (temperature <= 0)
            {
                // Greedy sampling
                id = candidates.SampleTokenGreedy(NativeHandle);
            }
            else
            {
                var mu = mirostat_mu ?? (2 * mirostatTau);

                switch (mirostat)
                {
                    case MirostatType.Mirostat:
                    {
                        const int mirostat_m = 100;
                        candidates.Temperature(NativeHandle, temperature);
                        id = candidates.SampleTokenMirostat(NativeHandle, mirostatTau, mirostatEta, mirostat_m, ref mu);
                        break;
                    }

                    case MirostatType.Mirostat2:
                    {
                        candidates.Temperature(NativeHandle, temperature);
                        id = candidates.SampleTokenMirostat2(NativeHandle, mirostatTau, mirostatEta, ref mu);
                        break;
                    }

                    default:
                    {
                        // No mirostat: run the filter chain in this fixed order, then sample
                        candidates.TopK(NativeHandle, topK);
                        candidates.TailFree(NativeHandle, tfsZ);
                        candidates.LocallyTypical(NativeHandle, typicalP);
                        candidates.TopP(NativeHandle, topP);
                        candidates.MinP(NativeHandle, minP);
                        candidates.Temperature(NativeHandle, temperature);
                        id = candidates.SampleToken(NativeHandle);
                        break;
                    }
                }

                // Hand the (possibly updated) mu back to the caller
                mirostat_mu = mu;
            }

            grammar?.AcceptToken(NativeHandle, id);

            return id;
        }

        /// <summary>
        /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
        /// </summary>
        /// <param name="logits_i">Index passed to the native handle to select which logits to read</param>
        /// <param name="lastTokens">Previously returned tokens; only the trailing ones are penalized</param>
        /// <param name="logitBias">Optional per-token values added to the logits before anything else</param>
        /// <param name="repeatLastTokensCount">Maximum number of trailing tokens to penalize (capped at <see cref="ContextSize"/>)</param>
        /// <param name="repeatPenalty">Repetition penalty passed to the candidates</param>
        /// <param name="alphaFrequency">Frequency penalty passed to the candidates</param>
        /// <param name="alphaPresence">Presence penalty passed to the candidates</param>
        /// <param name="penalizeNL">When false, the newline token's logit is restored to its pre-penalty value</param>
        /// <returns>The penalized token candidates</returns>
        public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
                                                int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f,
                                                float alphaPresence = .0f, bool penalizeNL = true)
        {
            var logits = NativeHandle.GetLogitsIth(logits_i);

            // Apply params.logit_bias map
            if (logitBias is not null)
            {
                foreach (var (key, value) in logitBias)
                    logits[(int)key] += value;
            }

            // Save the newline logit value (index 0 is read when the model has no newline token;
            // that value is never used, because the restore step below is skipped in that case)
            var nl_token = NativeHandle.ModelHandle.Tokens.Newline;
            var nl_logit = logits[(int?)nl_token ?? 0];

            // Convert logits into token candidates
            var candidates_p = LLamaTokenDataArray.Create(logits);

            // Extract most recently returned tokens
            var last_n_repeat = Math.Min((int)ContextSize, repeatLastTokensCount);
            var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();

            // Apply penalties to candidates
            candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence);

            // Restore newline token logit value if necessary
            if (!penalizeNL && nl_token.HasValue)
            {
                var candidatesSpan = candidates_p.data.Span;
                for (var i = 0; i < candidates_p.data.Length; i++)
                {
                    ref var item = ref candidatesSpan[i];
                    if (item.id == nl_token)
                        item.logit = nl_logit;
                }

                // A logit was overwritten, so the array is flagged as not sorted
                candidates_p.sorted = false;
            }

            return candidates_p;
        }

        #region eval overloads
        /// <summary>
        /// Decode a batch with the native context.
        /// </summary>
        /// <param name="batch">The batch to decode</param>
        /// <returns>
        /// The zero value of <see cref="DecodeResult"/> if the batch contains no tokens (the native handle is not called),
        /// otherwise the native result converted to <see cref="DecodeResult"/>
        /// </returns>
        /// <exception cref="ArgumentException">Thrown if the batch holds more tokens than <c>Params.BatchSize</c></exception>
        public DecodeResult Decode(LLamaBatch batch)
        {
            if (batch.TokenCount == 0)
                return 0;
            if (batch.TokenCount > Params.BatchSize)
                throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));

            return (DecodeResult)NativeHandle.Decode(batch);
        }

        /// <summary>
        /// Run <see cref="Decode"/> via <see cref="Task.Run(Func{DecodeResult}, CancellationToken)"/>.
        /// </summary>
        /// <param name="batch">The batch to decode</param>
        /// <param name="cancellationToken">Token passed to <c>Task.Run</c>; it is not passed on to <see cref="Decode"/></param>
        /// <returns>A task producing the result of <see cref="Decode"/></returns>
        public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
        {
            return Task.Run(() => Decode(batch), cancellationToken);
        }
        #endregion

        /// <inheritdoc />
        public void Dispose()
        {
            NativeHandle.Dispose();
        }

        /// <summary>
        /// The state of this context, which can be reloaded later
        /// </summary>
        public class State
            : SafeLLamaHandleBase
        {
            private readonly ulong _size;

            /// <summary>
            /// Get the size in bytes of this state object
            /// </summary>
            public ulong Size => _size;

            internal State(IntPtr memory, ulong size)
                : base(memory, true)
            {
                _size = size;
            }

            /// <inheritdoc />
            protected override bool ReleaseHandle()
            {
                Marshal.FreeHGlobal(handle);
                return true;
            }

            /// <summary>
            /// Convert this state to a byte array
            /// </summary>
            /// <returns>A managed copy of the state bytes</returns>
            [Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
            public byte[] ToByteArray()
            {
                var bytes = new byte[_size];
                Marshal.Copy(handle, bytes, 0, (int)_size);
                return bytes;
            }

            /// <summary>
            /// Load state from a byte array
            /// </summary>
            /// <param name="bytes">The bytes to copy into newly allocated unmanaged memory</param>
            /// <returns>A state owning an unmanaged copy of <paramref name="bytes"/></returns>
            [Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
            public static State FromByteArray(byte[] bytes)
            {
                var memory = Marshal.AllocHGlobal(bytes.Length);
                Marshal.Copy(bytes, 0, memory, bytes.Length);
                return new State(memory, (ulong)bytes.Length);
            }
        }

        /// <summary>
        /// The state of a single sequence, which can be reloaded later
        /// </summary>
        public class SequenceState
            : SafeLLamaHandleBase
        {
            private readonly ulong _size;

            /// <summary>
            /// Get the size in bytes of this state object
            /// </summary>
            public ulong Size => _size;

            internal SequenceState(IntPtr memory, ulong size)
                : base(memory, true)
            {
                _size = size;
            }

            /// <inheritdoc />
            protected override bool ReleaseHandle()
            {
                Marshal.FreeHGlobal(handle);
                return true;
            }

            /// <summary>
            /// Copy bytes to a destination pointer.
            /// </summary>
            /// <param name="dst">Destination to write to</param>
            /// <param name="length">Length of the destination buffer</param>
            /// <param name="offset">Offset from start of src to start copying from</param>
            /// <returns>Number of bytes written to destination (zero when <paramref name="offset"/> is at or past the end of the state)</returns>
            public unsafe ulong CopyTo(byte* dst, ulong length, ulong offset = 0)
            {
                // `_size - offset` is unsigned: an offset beyond the end would wrap around to a huge
                // value and make the copy read past the end of the state buffer. Nothing is left to copy.
                if (offset >= _size)
                    return 0;

                var copy = Math.Min(length, _size - offset);

                var src = (byte*)DangerousGetHandle();
                src += offset;

                Buffer.MemoryCopy(src, dst, length, copy);

                return copy;
            }
        }
    }
}