diff --git a/LLama.Examples/Assets/vicuna-llava-v16.txt b/LLama.Examples/Assets/vicuna-llava-v16.txt new file mode 100644 index 00000000..7ba4018b --- /dev/null +++ b/LLama.Examples/Assets/vicuna-llava-v16.txt @@ -0,0 +1 @@ +\nUSER:\nProvide a full description.\nASSISTANT:\n diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs index b74170e3..4016b401 100644 --- a/LLama.Examples/ExampleRunner.cs +++ b/LLama.Examples/ExampleRunner.cs @@ -13,6 +13,7 @@ public class ExampleRunner { "Chat Session: Automatic conversation", TalkToYourself.Run }, { "Chat Session: Chinese characters", ChatChineseGB2312.Run }, { "Executor: Interactive mode chat", InteractiveModeExecute.Run }, + { "Executor: Llava Interactive mode chat", LlavaInteractiveModeExecute.Run }, { "Executor: Instruct mode chat", InstructModeExecute.Run }, { "Executor: Stateless mode chat", StatelessModeExecute.Run }, { "Save and Load: chat session", SaveAndLoadSession.Run }, diff --git a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs new file mode 100644 index 00000000..1cb8e8fd --- /dev/null +++ b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs @@ -0,0 +1,51 @@ +using LLama.Common; + +namespace LLama.Examples.Examples +{ + public class LlavaInteractiveModeExecute + { + public static async Task Run() + { + string multiModalProj = UserSettings.GetMMProjPath(); + string modelPath = UserSettings.GetModelPath(); + string imagePath = UserSettings.GetImagePath(); + + var prompt = (await File.ReadAllTextAsync("Assets/vicuna-llava-v16.txt")).Trim(); + + var parameters = new ModelParams(modelPath) + { + ContextSize = 4096, + Seed = 1337, + GpuLayerCount = 5 + }; + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + + // Llava Init + using var clipModel = LLavaWeights.LoadFromFile(multiModalProj); + + var ex = new InteractiveExecutor(context, clipModel ); + + 
ex.ImagePath = imagePath; + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 1024 and the context size is 4096. "); + Console.ForegroundColor = ConsoleColor.White; + + Console.Write(prompt); + + var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List { "USER:" }, MaxTokens = 1024 }; + + while (true) + { + await foreach (var text in ex.InferAsync(prompt, inferenceParams)) + { + Console.Write(text); + } + Console.ForegroundColor = ConsoleColor.Green; + prompt = Console.ReadLine(); + Console.ForegroundColor = ConsoleColor.White; + } + } + } +} diff --git a/LLama.Examples/Examples/StatelessModeExecute.cs b/LLama.Examples/Examples/StatelessModeExecute.cs index e46a024e..762cd24d 100644 --- a/LLama.Examples/Examples/StatelessModeExecute.cs +++ b/LLama.Examples/Examples/StatelessModeExecute.cs @@ -21,7 +21,7 @@ namespace LLama.Examples.Examples Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the inference is an one-time job. That says, the previous input and response has " + "no impact on the current response. Now you can ask it questions. Note that in this example, no prompt was set for LLM and the maximum response tokens is 50. " + - "It may not perform well because of lack of prompt. This is also an example that could indicate the improtance of prompt in LLM. To improve it, you can add " + + "It may not perform well because of lack of prompt. This is also an example that could indicate the importance of prompt in LLM. 
To improve it, you can add " + "a prompt for it yourself!"); Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index ec6bb10d..c6667c5a 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -67,6 +67,9 @@ PreserveNewest + + PreserveNewest + diff --git a/LLama.Examples/UserSettings.cs b/LLama.Examples/UserSettings.cs index 088a628e..c21a9075 100644 --- a/LLama.Examples/UserSettings.cs +++ b/LLama.Examples/UserSettings.cs @@ -4,54 +4,84 @@ namespace LLama.Examples; internal static class UserSettings { - private static readonly string SettingsFilePath = Path.Join(AppContext.BaseDirectory, "DefaultModel.env"); + private static readonly string SettingsModelPath = Path.Join(AppContext.BaseDirectory, "DefaultModel.env"); + private static readonly string SettingsMMprojPath = Path.Join(AppContext.BaseDirectory, "DefaultMMProj.env"); + private static readonly string SettingsImagePath = Path.Join(AppContext.BaseDirectory, "DefaultImage.env"); - private static string? ReadDefaultModelPath() + private static string? ReadDefaultPath(string file) { - if (!File.Exists(SettingsFilePath)) + if (!File.Exists(file)) return null; - string path = File.ReadAllText(SettingsFilePath).Trim(); + string path = File.ReadAllText(file).Trim(); if (!File.Exists(path)) return null; return path; } - private static void WriteDefaultModelPath(string path) + private static void WriteDefaultPath(string settings, string path) { - File.WriteAllText(SettingsFilePath, path); + File.WriteAllText(settings, path); } public static string GetModelPath(bool alwaysPrompt = false) { - var defaultPath = ReadDefaultModelPath(); + var defaultPath = ReadDefaultPath(SettingsModelPath); var path = defaultPath is null || alwaysPrompt ? 
PromptUserForPath() : PromptUserForPathWithDefault(defaultPath); if (File.Exists(path)) - WriteDefaultModelPath(path); + WriteDefaultPath(SettingsModelPath, path); return path; } + + // TODO: Refactorize + public static string GetMMProjPath(bool alwaysPrompt = false) + { + var defaultPath = ReadDefaultPath(SettingsMMprojPath); + var path = defaultPath is null || alwaysPrompt + ? PromptUserForPath("MMProj") + : PromptUserForPathWithDefault(defaultPath, "MMProj"); + + if (File.Exists(path)) + WriteDefaultPath(SettingsMMprojPath, path); + + return path; + } + + // TODO: Refactorize + public static string GetImagePath(bool alwaysPrompt = false) + { + var defaultPath = ReadDefaultPath(SettingsImagePath); + var path = defaultPath is null || alwaysPrompt + ? PromptUserForPath("image") + : PromptUserForPathWithDefault(defaultPath, "image"); + + if (File.Exists(path)) + WriteDefaultPath(SettingsImagePath, path); + + return path; + } - private static string PromptUserForPath() + private static string PromptUserForPath(string text = "model") { return AnsiConsole.Prompt( - new TextPrompt("Please input your model path:") + new TextPrompt(string.Format("Please input your {0} path:", text) ) .PromptStyle("white") - .Validate(File.Exists, "[red]ERROR: invalid model file path - file does not exist[/]") + .Validate(File.Exists, string.Format("[red]ERROR: invalid {0} file path - file does not exist[/]", text) ) ); } - private static string PromptUserForPathWithDefault(string defaultPath) + private static string PromptUserForPathWithDefault(string defaultPath, string text = "model") { return AnsiConsole.Prompt( - new TextPrompt("Please input your model path (or ENTER for default):") + new TextPrompt(string.Format("Please input your {0} path (or ENTER for default):", text) ) .DefaultValue(defaultPath) .PromptStyle("white") - .Validate(File.Exists, "[red]ERROR: invalid model file path - file does not exist[/]") + .Validate(File.Exists, string.Format("[red]ERROR: invalid {0} file path - 
file does not exist[/]", text)) ); } } diff --git a/LLama.Unittest/LLavaWeightsTests.cs b/LLama.Unittest/LLavaWeightsTests.cs index e1fe1065..0e460bc5 100644 --- a/LLama.Unittest/LLavaWeightsTests.cs +++ b/LLama.Unittest/LLavaWeightsTests.cs @@ -31,23 +31,23 @@ namespace LLama.Unittest _llamaWeights.Dispose(); _lLavaWeights.Dispose(); } - - - [Fact(Skip = "Very slow in CI")] + [Fact] public void EmbedImageAsFileName() { int n_past = 0; - Assert.True( _lLavaWeights.EmbedImage( _context, Constants.LLavaImage, ref n_past ) ); - } - - [Fact(Skip = "Very slow in CI")] + SafeLlavaImageEmbedHandle emb = _lLavaWeights.CreateImageEmbeddings(_context, Constants.LLavaImage); + Assert.True( _lLavaWeights.EvalImageEmbed( _context, emb, ref n_past ) ); + } + + [Fact] public void EmbedImageAsBinary() { int n_past = 0; byte[] image = System.IO.File.ReadAllBytes(Constants.LLavaImage); - Assert.True( _lLavaWeights.EmbedImage( _context, image, ref n_past ) ); - } + SafeLlavaImageEmbedHandle emb = _lLavaWeights.CreateImageEmbeddings(_context, image); + Assert.True( _lLavaWeights.EvalImageEmbed( _context, emb, ref n_past ) ); + } } } diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs index ef5453a7..64ca0a0b 100644 --- a/LLama/Abstractions/ILLamaExecutor.cs +++ b/LLama/Abstractions/ILLamaExecutor.cs @@ -12,7 +12,14 @@ namespace LLama.Abstractions /// The loaded context for this executor. /// public LLamaContext Context { get; } - + + // LLava Section + public bool IsMultiModal { get; } + public bool MultiModalProject { get; } + public LLavaWeights? ClipModel { get; } + public string ImagePath { get; set; } + + /// /// Asynchronously infers a response from the model. 
/// diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index ec72a25a..ba3fcd64 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -64,6 +64,18 @@ namespace LLama /// public LLamaContext Context { get; } + // LLava Section + public bool IsMultiModal + { + get + { + return ClipModel != null && !string.IsNullOrEmpty(ImagePath); + } + } + public bool MultiModalProject { get; } + public LLavaWeights? ClipModel { get; } + public string ImagePath { get; set; } + /// /// Current "mu" value for mirostat sampling /// @@ -86,6 +98,13 @@ namespace LLama _last_n_tokens = new FixedSizeQueue((int)Context.ContextSize); _decoder = new StreamingTokenDecoder(context); } + + public StatefulExecutorBase(LLamaContext context, LLavaWeights lLavaWeights, ILogger? logger = null) : + this( context, logger ) + { + ClipModel = lLavaWeights; + MultiModalProject = true; + } /// /// This API is currently not verified. diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 2a14eeaf..3041b2ef 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -21,6 +21,11 @@ namespace LLama { private bool _is_prompt_run = true; private readonly LLamaToken _llama_token_newline; + + // LLava + private int _EmbedImagePosition = -1; + private SafeLlavaImageEmbedHandle _imageEmbedHandle = null; + private bool _imageInPrompt = false; /// /// @@ -32,6 +37,12 @@ namespace LLama { _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle); } + + public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null) + : base(context, clipModel, logger) + { + _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle); + } /// public override ExecutorBaseState GetStateData() @@ -107,8 +118,38 @@ namespace LLama { if (_is_prompt_run) { - // When running the first input (prompt) in inteactive mode, we should specially process it. 
- _embed_inps = Context.Tokenize(text, true).ToList(); + // When running the first input (prompt) in interactive mode, we should specially process it. + if (!this.IsMultiModal) + { + _embed_inps = Context.Tokenize(text, true).ToList(); + } + else + { + // If the prompt contains the tag extract this. + _imageInPrompt = text.Contains(""); + if (_imageInPrompt) + { + if (!string.IsNullOrEmpty(ImagePath)) + { + _imageEmbedHandle = SafeLlavaImageEmbedHandle.CreateFromFileName( ClipModel.NativeHandle, Context, ImagePath); + } + + int imageIndex = text.IndexOf(""); + // Tokenize segment 1 (before tag) + string preImagePrompt = text.Substring(0, imageIndex); + var segment1 = Context.Tokenize(preImagePrompt, true); + // Remember the position to add the image embeddings + _EmbedImagePosition = segment1.Length; + string postImagePrompt = text.Substring(imageIndex + 7); + var segment2 = Context.Tokenize(postImagePrompt, false); + _embed_inps.AddRange(segment1); + _embed_inps.AddRange(segment2); + } + else + { + _embed_inps = Context.Tokenize(text, true).ToList(); + } + } } else { @@ -170,9 +211,30 @@ namespace LLama TryReuseMathingPrefix(); - var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount); - if (result != DecodeResult.Ok) - throw new LLamaDecodeError(result); + // Changes to support Multi-Modal LLMs. 
+ // + (DecodeResult, int) header, end, result; + if (IsMultiModal && _EmbedImagePosition > 0) + { + // Previous to Image + header = Context.NativeHandle.Decode(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount); + if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1); + + // Image + ClipModel.EvalImageEmbed(Context, _imageEmbedHandle, ref _pastTokensCount); + + // Post-image + end = Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount); + + _EmbedImagePosition = -1; + + } + else + { + result = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount); + if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1); + } + if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 2a20d14a..e702f47f 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -23,7 +23,13 @@ namespace LLama private readonly IContextParams _params; private readonly ILogger? _logger; private readonly LLamaBatch _batch; - + + // LLava Section + public bool IsMultiModal => false; + public bool MultiModalProject { get; } + public LLavaWeights ClipModel { get; } + public string ImagePath { get; set; } + /// /// The context used by the executor when running the inference. /// @@ -46,6 +52,7 @@ namespace LLama Context.Dispose(); } + /// public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams? 
inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { diff --git a/LLama/LLavaWeights.cs b/LLama/LLavaWeights.cs index 301fb729..9f93922d 100644 --- a/LLama/LLavaWeights.cs +++ b/LLama/LLavaWeights.cs @@ -19,30 +19,21 @@ public sealed class LLavaWeights : IDisposable return new LLavaWeights(weights); } - /// - /// Embed the image from file into llama context - /// - /// - /// - /// - /// - public bool EmbedImage(LLamaContext ctxLlama, string Image, ref int n_past ) + public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, Byte[] image ) { - return NativeHandle.EmbedImage(ctxLlama, Image, ref n_past ); + return NativeHandle.CreateImageEmbeddings(ctxLlama, image ); } - /// - /// Embed the image from binary into llama context. - /// - /// - /// - /// - /// - public bool EmbedImage(LLamaContext ctxLlama, Byte[] Image, ref int n_past ) + public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image ) { - return NativeHandle.EmbedImage(ctxLlama, Image, ref n_past ); + return NativeHandle.CreateImageEmbeddings(ctxLlama, image ); } - + + public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imageEmbed, ref int n_past) + { + return NativeHandle.EvalImageEmbed( ctxLlama, imageEmbed, ref n_past ); + } + public void Dispose() { NativeHandle.Dispose(); diff --git a/LLama/Native/NativeApi.LLava.cs b/LLama/Native/NativeApi.LLava.cs index 7930e375..895fe43e 100644 --- a/LLama/Native/NativeApi.LLava.cs +++ b/LLama/Native/NativeApi.LLava.cs @@ -9,18 +9,20 @@ public static unsafe partial class NativeApi /// /// Sanity check for clip <-> llava embed size match /// - /// + /// LLama Context + /// Llava Model + /// True if validate successfully [DllImport(llavaLibraryName, EntryPoint = "llava_validate_embed_size", CallingConvention = CallingConvention.Cdecl)] public static extern bool llava_validate_embed_size( SafeLLamaContextHandle ctxLlama, SafeLlavaModelHandle 
ctxClip); /// /// Build an image embed from image file bytes /// - /// - /// - /// - /// - /// + /// SafeHandle to the Clip Model + /// Number of threads + /// Binary image in jpeg format + /// Bytes length of the image + /// SafeHandle to the Embeddings [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_make_with_bytes", CallingConvention = CallingConvention.Cdecl)] public static extern @@ -30,10 +32,10 @@ public static unsafe partial class NativeApi /// /// Build an image embed from a path to an image filename /// - /// - /// - /// - /// + /// SafeHandle to the Clip Model + /// Number of threads + /// Image filename (jpeg) to generate embeddings + /// SafeHandle to the embeddings [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_make_with_filename", CallingConvention = CallingConvention.Cdecl)] public static extern SafeLlavaImageEmbedHandle llava_image_embed_make_with_filename(SafeLlavaModelHandle ctx_clip, int n_threads, @@ -42,19 +44,19 @@ public static unsafe partial class NativeApi /// /// Free an embedding made with llava_image_embed_make_* /// - /// - /// + /// Embeddings to release [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_free", CallingConvention = CallingConvention.Cdecl)] - public static extern SafeLlavaImageEmbedHandle llava_image_embed_free(IntPtr embed); + public static extern void llava_image_embed_free(IntPtr embed); /// /// Write the image represented by embed into the llama context with batch size n_batch, starting at context /// pos n_past. on completion, n_past points to the next position in the context after the image embed. 
/// - /// ctx_llama - /// + /// Llama Context + /// Embedding handle + /// True on success [DllImport(llavaLibraryName, EntryPoint = "llava_eval_image_embed", CallingConvention = CallingConvention.Cdecl)] - public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctc_llama, SafeLlavaImageEmbedHandle embed, + public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, SafeLlavaImageEmbedHandle embed, int n_batch, ref int n_past); } \ No newline at end of file diff --git a/LLama/Native/SafeLlavaModelHandle.cs b/LLama/Native/SafeLlavaModelHandle.cs index f9544041..bd49a578 100644 --- a/LLama/Native/SafeLlavaModelHandle.cs +++ b/LLama/Native/SafeLlavaModelHandle.cs @@ -11,7 +11,7 @@ using LLama.Exceptions; namespace LLama.Native { /// - /// A reference to a set of llava model weights + /// A reference to a set of llava model weights. /// public sealed class SafeLlavaModelHandle : SafeLLamaHandleBase @@ -36,9 +36,10 @@ namespace LLama.Native /// /// Load a model from the given file path into memory /// - /// - /// - /// + /// MMP File (Multi-Modal Projections) + /// Verbosity level + /// SafeHandle of the Clip Model + /// /// public static SafeLlavaModelHandle LoadFromFile(string modelPath, int verbosity ) { @@ -56,31 +57,37 @@ namespace LLama.Native } /// - /// Embed the image from file in llama context + /// Create the Image Embeddings. 
/// - /// - /// - /// - /// - public bool EmbedImage(LLamaContext ctxLlama, string image, ref int n_past) + /// LLama Context + /// Image filename (it supports jpeg format only) + /// return the SafeHandle of these embeddings + public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image) { - var ImageEmbed = SafeLlavaImageEmbedHandle.CreateFromFileName(this, ctxLlama, image); - bool result = NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, ImageEmbed, (int)ctxLlama.Params.BatchSize, ref n_past ); - return result; + return SafeLlavaImageEmbedHandle.CreateFromFileName(this, ctxLlama, image); } /// - /// Embed the image from binary in llama context + /// Create the Image Embeddings. + /// + /// LLama Context + /// Image in binary format (it supports jpeg format only) + /// return the SafeHandle of these embeddings + public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image ) + { + return SafeLlavaImageEmbedHandle.CreateFromMemory(this, ctxLlama, image ); + } + + /// + /// Evaluates the image embeddings. 
/// - /// - /// jpeg image + /// Llama Context + /// The current embeddings to evaluate /// - /// - public bool EmbedImage(LLamaContext ctxLlama, Byte[] image, ref int n_past ) + /// True on success + public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imageEmbed, ref int n_past) { - var ImageEmbed = SafeLlavaImageEmbedHandle.CreateFromMemory(this, ctxLlama, image ); - bool result = NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, ImageEmbed, (int)ctxLlama.Params.BatchSize, ref n_past ); - return result; + return NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, imageEmbed, (int)ctxLlama.Params.BatchSize, ref n_past ); } /// @@ -95,7 +102,7 @@ namespace LLama.Native /// /// Frees MULTI MODAL PROJECTIONS model / Clip Model /// - /// + /// Internal Pointer to the model [DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_free", CallingConvention = CallingConvention.Cdecl)] private static extern void clip_free(IntPtr ctx);