Low level new loading system (tags/v0.4.2-preview)
@@ -13,65 +13,101 @@ namespace LLama.Native
     /// <summary>
     /// RNG seed, -1 for random
     /// </summary>
     public int seed;
     /// <summary>
     /// text context
     /// </summary>
     public int n_ctx;
     /// <summary>
     /// prompt processing batch size
     /// </summary>
     public int n_batch;
+    /// <summary>
+    /// grouped-query attention (TEMP - will be moved to model hparams)
+    /// </summary>
+    public int n_gqa;
+    /// <summary>
+    /// rms norm epsilon (TEMP - will be moved to model hparams)
+    /// </summary>
+    public float rms_norm_eps;
     /// <summary>
     /// number of layers to store in VRAM
     /// </summary>
     public int n_gpu_layers;
     /// <summary>
     /// the GPU that is used for scratch and small tensors
     /// </summary>
     public int main_gpu;
     /// <summary>
     /// how to split layers across multiple GPUs
     /// </summary>
-    public TensorSplits tensor_split;
+    public float[] tensor_split;
+    /// <summary>
+    /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
+    /// RoPE base frequency
+    /// </summary>
+    public float rope_freq_base;
+    /// <summary>
+    /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
+    /// RoPE frequency scaling factor
+    /// </summary>
+    public float rope_freq_scale;
     /// <summary>
     /// called with a progress value between 0 and 1, pass NULL to disable
     /// </summary>
     public IntPtr progress_callback;
     /// <summary>
     /// context pointer passed to the progress callback
     /// </summary>
     public IntPtr progress_callback_user_data;
     /// <summary>
     /// if true, reduce VRAM usage at the cost of performance
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool low_vram;
     /// <summary>
     /// use fp16 for KV cache
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool f16_kv;
     /// <summary>
     /// the llama_eval() call computes all logits, not just the last one
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool logits_all;
     /// <summary>
     /// only load the vocabulary, no weights
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool vocab_only;
     /// <summary>
     /// use mmap if possible
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool use_mmap;
     /// <summary>
     /// force system to keep model in RAM
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool use_mlock;
     /// <summary>
     /// embedding mode only
     /// </summary>
@@ -1,6 +1,4 @@
 using System;
-using System.Collections.Generic;
-using System.IO;
 using System.Runtime.InteropServices;
 using System.Text;
 using LLama.Exceptions;
@@ -29,7 +27,7 @@ namespace LLama.Native
     }
     private const string libraryName = "libllama";
-    [DllImport("libllama", EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
+    [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
     public static extern bool llama_empty_call();
     [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -56,7 +54,10 @@ namespace LLama.Native
     /// <param name="params_"></param>
     /// <returns></returns>
     [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern IntPtr llama_init_from_file(string path_model, LLamaContextParams params_);
+    public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams params_);
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams params_);
     /// <summary>
     /// not great API - very likely to change.
@@ -65,6 +66,7 @@ namespace LLama.Native
     /// </summary>
     [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
     public static extern void llama_backend_init(bool numa);
     /// <summary>
     /// Frees all allocated memory
     /// </summary>
@@ -72,6 +74,13 @@ namespace LLama.Native
     [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
     public static extern void llama_free(IntPtr ctx);
+    /// <summary>
+    /// Frees all allocated memory associated with a model
+    /// </summary>
+    /// <param name="model"></param>
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern void llama_free_model(IntPtr model);
     /// <summary>
     /// Apply a LoRA adapter to a loaded model
     /// path_base_model is the path to a higher quality model to use as a base for
@@ -79,13 +88,13 @@ namespace LLama.Native
     /// The model needs to be reloaded before applying a new adapter, otherwise the adapter
     /// will be applied on top of the previous one
     /// </summary>
-    /// <param name="ctx"></param>
+    /// <param name="model_ptr"></param>
     /// <param name="path_lora"></param>
     /// <param name="path_base_model"></param>
     /// <param name="n_threads"></param>
     /// <returns>Returns 0 on success</returns>
     [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern int llama_apply_lora_from_file(SafeLLamaContextHandle ctx, string path_lora, string path_base_model, int n_threads);
+    public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads);
     /// <summary>
     /// Returns the number of tokens in the KV cache
@@ -294,5 +303,20 @@ namespace LLama.Native
     /// <returns></returns>
     [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
     public static extern IntPtr llama_print_system_info();
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern int llama_n_vocab_from_model(SafeLlamaModelHandle model);
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern int llama_n_ctx_from_model(SafeLlamaModelHandle model);
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern int llama_n_embd_from_model(SafeLlamaModelHandle model);
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken);
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos);
 }
}
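These imports replace the single llama_init_from_file entry point with separate model and context lifetimes. A minimal sketch of the intended call sequence, using the safe-handle wrappers introduced further down (the model path is a placeholder and lparams is assumed to be a populated LLamaContextParams):

    // Load the weights once, then create a context over them.
    using var model = SafeLlamaModelHandle.LoadFromFile("path/to/model.bin", lparams);
    using var ctx = SafeLLamaContextHandle.Create(model, lparams);

    // Disposing the context drops its reference on the model; the weights are
    // freed via llama_free_model once nothing holds the model handle any more.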
@@ -1,26 +1,61 @@
 using System;
-using System.Collections.Generic;
-using System.Runtime.InteropServices;
-using System.Text;
+using LLama.Exceptions;
 namespace LLama.Native
 {
-    public class SafeLLamaContextHandle: SafeLLamaHandleBase
+    /// <summary>
+    /// A safe wrapper around a llama_context
+    /// </summary>
+    public class SafeLLamaContextHandle
+        : SafeLLamaHandleBase
     {
-        protected SafeLLamaContextHandle()
-        {
-        }
+        /// <summary>
+        /// This field guarantees that a reference to the model is held for as long as this handle is held
+        /// </summary>
+        private SafeLlamaModelHandle? _model;
-        public SafeLLamaContextHandle(IntPtr handle)
+        /// <summary>
+        /// Create a new SafeLLamaContextHandle
+        /// </summary>
+        /// <param name="handle">pointer to an allocated llama_context</param>
+        /// <param name="model">the model which this context was created from</param>
+        public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
             : base(handle)
         {
+            // Increment the model reference count while this context exists
+            _model = model;
+            var success = false;
+            _model.DangerousAddRef(ref success);
+            if (!success)
+                throw new RuntimeError("Failed to increment model refcount");
         }
+        /// <inheritdoc />
         protected override bool ReleaseHandle()
         {
+            // Decrement refcount on model
+            _model?.DangerousRelease();
+            _model = null;
             NativeApi.llama_free(handle);
             SetHandle(IntPtr.Zero);
             return true;
         }
+        /// <summary>
+        /// Create a new llama_state for the given model
+        /// </summary>
+        /// <param name="model"></param>
+        /// <param name="lparams"></param>
+        /// <returns></returns>
+        /// <exception cref="RuntimeError"></exception>
+        public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
+        {
+            var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
+            if (ctx_ptr == IntPtr.Zero)
+                throw new RuntimeError("Failed to create context from model");
+            return new(ctx_ptr, model);
+        }
     }
 }
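Because the context handle takes a reference on its model handle, a single set of loaded weights can safely back more than one context; the model is only released after every context created from it has been disposed. For example (a sketch, reusing the placeholder lparams from the earlier example):

    using var model = SafeLlamaModelHandle.LoadFromFile("path/to/model.bin", lparams);
    using var ctxA = SafeLLamaContextHandle.Create(model, lparams); // independent KV cache
    using var ctxB = SafeLLamaContextHandle.Create(model, lparams); // same weights, no second load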
@@ -1,11 +1,13 @@
 using System;
-using System.Collections.Generic;
 using System.Runtime.InteropServices;
-using System.Text;
 namespace LLama.Native
 {
-    public abstract class SafeLLamaHandleBase: SafeHandle
+    /// <summary>
+    /// Base class for all llama handles to native resources
+    /// </summary>
+    public abstract class SafeLLamaHandleBase
+        : SafeHandle
     {
         private protected SafeLLamaHandleBase()
             : base(IntPtr.Zero, ownsHandle: true)
@@ -24,8 +26,10 @@ namespace LLama.Native
             SetHandle(handle);
         }
+        /// <inheritdoc />
         public override bool IsInvalid => handle == IntPtr.Zero;
+        /// <inheritdoc />
         public override string ToString()
             => $"0x{handle.ToString("x16")}";
     }
@@ -0,0 +1,161 @@
+using System;
+using System.Text;
+using LLama.Exceptions;
+namespace LLama.Native
+{
+    /// <summary>
+    /// A reference to a set of llama model weights
+    /// </summary>
+    public class SafeLlamaModelHandle
+        : SafeLLamaHandleBase
+    {
+        /// <summary>
+        /// Total number of tokens in vocabulary of this model
+        /// </summary>
+        public int VocabCount { get; set; }
+        /// <summary>
+        /// Total number of tokens in the context
+        /// </summary>
+        public int ContextSize { get; set; }
+        /// <summary>
+        /// Dimension of embedding vectors
+        /// </summary>
+        public int EmbeddingCount { get; set; }
+        internal SafeLlamaModelHandle(IntPtr handle)
+            : base(handle)
+        {
+            VocabCount = NativeApi.llama_n_vocab_from_model(this);
+            ContextSize = NativeApi.llama_n_ctx_from_model(this);
+            EmbeddingCount = NativeApi.llama_n_embd_from_model(this);
+        }
+        /// <inheritdoc />
+        protected override bool ReleaseHandle()
+        {
+            NativeApi.llama_free_model(handle);
+            SetHandle(IntPtr.Zero);
+            return true;
+        }
+        /// <summary>
+        /// Load a model from the given file path into memory
+        /// </summary>
+        /// <param name="modelPath"></param>
+        /// <param name="lparams"></param>
+        /// <returns></returns>
+        /// <exception cref="RuntimeError"></exception>
+        public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
+        {
+            var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
+            if (model_ptr == IntPtr.Zero)
+                throw new RuntimeError($"Failed to load model {modelPath}.");
+            return new SafeLlamaModelHandle(model_ptr);
+        }
+        #region LoRA
+        /// <summary>
+        /// Apply a LoRA adapter to a loaded model
+        /// </summary>
+        /// <param name="lora"></param>
+        /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
+        /// adapter. Can be NULL to use the current loaded model.</param>
+        /// <param name="threads"></param>
+        /// <exception cref="RuntimeError"></exception>
+        public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
+        {
+            var err = NativeApi.llama_model_apply_lora_from_file(
+                this,
+                lora,
+                string.IsNullOrEmpty(modelBase) ? null : modelBase,
+                threads
+            );
+            if (err != 0)
+                throw new RuntimeError("Failed to apply lora adapter.");
+        }
+        #endregion
+        #region tokenize
+        /// <summary>
+        /// Convert a single llama token into string bytes
+        /// </summary>
+        /// <param name="llama_token"></param>
+        /// <returns></returns>
+        public ReadOnlySpan<byte> TokenToSpan(int llama_token)
+        {
+            unsafe
+            {
+                var bytes = new ReadOnlySpan<byte>(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
+                var terminator = bytes.IndexOf((byte)0);
+                return bytes.Slice(0, terminator);
+            }
+        }
+        /// <summary>
+        /// Convert a single llama token into a string
+        /// </summary>
+        /// <param name="llama_token"></param>
+        /// <param name="encoding">Encoding to use to decode the bytes into a string</param>
+        /// <returns></returns>
+        public string TokenToString(int llama_token, Encoding encoding)
+        {
+            var span = TokenToSpan(llama_token);
+            if (span.Length == 0)
+                return "";
+            unsafe
+            {
+                fixed (byte* ptr = &span[0])
+                {
+                    return encoding.GetString(ptr, span.Length);
+                }
+            }
+        }
+        /// <summary>
+        /// Convert a string of text into tokens
+        /// </summary>
+        /// <param name="text"></param>
+        /// <param name="add_bos"></param>
+        /// <param name="encoding"></param>
+        /// <returns></returns>
+        public int[] Tokenize(string text, bool add_bos, Encoding encoding)
+        {
+            // Convert string to bytes, adding one extra byte to the end (null terminator)
+            var bytesCount = encoding.GetByteCount(text);
+            var bytes = new byte[bytesCount + 1];
+            unsafe
+            {
+                fixed (char* charPtr = text)
+                fixed (byte* bytePtr = &bytes[0])
+                {
+                    encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length);
+                }
+            }
+            unsafe
+            {
+                fixed (byte* bytesPtr = &bytes[0])
+                {
+                    // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
+                    var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);
+                    // Tokenize again, this time outputting into an array of exactly the right size
+                    var tokens = new int[count];
+                    fixed (int* tokensPtr = &tokens[0])
+                    {
+                        count = NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
+                        return tokens;
+                    }
+                }
+            }
+        }
+        #endregion
+    }
+}
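The model handle also exposes model-level tokenization, so text can be converted to and from tokens without creating a context. A usage sketch (assumes the usual System.Text using and the placeholder model handle from the earlier examples):

    // Encode a prompt into token ids, then decode it token by token.
    int[] tokens = model.Tokenize("Hello, world!", add_bos: true, Encoding.UTF8);
    foreach (var token in tokens)
        Console.Write(model.TokenToString(token, Encoding.UTF8));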
@@ -31,24 +31,12 @@ namespace LLama.OldVersion
                 throw new FileNotFoundException($"The model file does not exist: {@params.model}");
             }
-            var ctx_ptr = NativeApi.llama_init_from_file(@params.model, lparams);
-            if (ctx_ptr == IntPtr.Zero)
-            {
-                throw new RuntimeError($"Failed to load model {@params.model}.");
-            }
-            SafeLLamaContextHandle ctx = new(ctx_ptr);
+            var model = SafeLlamaModelHandle.LoadFromFile(@params.model, lparams);
+            var ctx = SafeLLamaContextHandle.Create(model, lparams);
             if (!string.IsNullOrEmpty(@params.lora_adapter))
-            {
-                int err = NativeApi.llama_apply_lora_from_file(ctx, @params.lora_adapter,
-                    string.IsNullOrEmpty(@params.lora_base) ? null : @params.lora_base, @params.n_threads);
-                if (err != 0)
-                {
-                    throw new RuntimeError("Failed to apply lora adapter.");
-                }
-            }
+                model.ApplyLoraFromFile(@params.lora_adapter, @params.lora_base, @params.n_threads);
             return ctx;
         }
@@ -28,40 +28,25 @@ namespace LLama
             lparams.logits_all = @params.Perplexity;
             lparams.embedding = @params.EmbeddingMode;
             lparams.low_vram = @params.LowVram;
-            if(@params.TensorSplits.Length != 1)
+            if (@params.TensorSplits.Length != 1)
             {
                 throw new ArgumentException("Currently multi-gpu support is not supported by " +
                     "both llama.cpp and LLamaSharp.");
             }
-            lparams.tensor_split = new TensorSplits()
-            {
-                Item1 = @params.TensorSplits[0]
-            };
+            lparams.tensor_split = @params.TensorSplits;
             if (!File.Exists(@params.ModelPath))
             {
                 throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
             }
-            var ctx_ptr = NativeApi.llama_init_from_file(@params.ModelPath, lparams);
-            if (ctx_ptr == IntPtr.Zero)
-            {
-                throw new RuntimeError($"Failed to load model {@params.ModelPath}.");
-            }
-            SafeLLamaContextHandle ctx = new(ctx_ptr);
+            var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
+            var ctx = SafeLLamaContextHandle.Create(model, lparams);
             if (!string.IsNullOrEmpty(@params.LoraAdapter))
-            {
-                int err = NativeApi.llama_apply_lora_from_file(ctx, @params.LoraAdapter,
-                    string.IsNullOrEmpty(@params.LoraBase) ? null : @params.LoraBase, @params.Threads);
-                if (err != 0)
-                {
-                    throw new RuntimeError("Failed to apply lora adapter.");
-                }
-            }
+                model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
             return ctx;
         }
@@ -78,7 +63,7 @@ namespace LLama
             return res.Take(n);
         }
-        public unsafe static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
+        public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
         {
             var logits = NativeApi.llama_get_logits(ctx);
             return new Span<float>(logits, length);
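As a usage note, the length passed to GetLogits is normally the vocabulary size. This snippet assumes the helper lives on the Utils class and that the model/context pair from the earlier examples is in scope:

    // Read the logits produced by the last llama_eval call and pick the
    // highest-scoring token (greedy sampling).
    var logits = Utils.GetLogits(ctx, model.VocabCount);
    var best = 0;
    for (var i = 1; i < logits.Length; i++)
        if (logits[i] > logits[best])
            best = i;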