From 43677c511c68f69d264e16f3ae6bf7e7f5fc4d5f Mon Sep 17 00:00:00 2001 From: SignalRT Date: Tue, 26 Mar 2024 23:13:39 +0100 Subject: [PATCH] Change interface to support multiple images and add the capabitlity to render the image in the console --- .../Examples/LlavaInteractiveModeExecute.cs | 96 ++++++++++++++++--- LLama.Examples/LLama.Examples.csproj | 1 + LLama/Abstractions/ILLamaExecutor.cs | 4 +- LLama/LLamaExecutorBase.cs | 5 +- LLama/LLamaInteractExecutor.cs | 70 ++++++++------ LLama/LLamaStatelessExecutor.cs | 5 +- 6 files changed, 132 insertions(+), 49 deletions(-) diff --git a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs index 1cb8e8fd..4932a2ae 100644 --- a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs +++ b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs @@ -1,4 +1,7 @@ -using LLama.Common; +using System.Text.RegularExpressions; +using LLama.Batched; +using LLama.Common; +using Spectre.Console; namespace LLama.Examples.Examples { @@ -8,15 +11,15 @@ namespace LLama.Examples.Examples { string multiModalProj = UserSettings.GetMMProjPath(); string modelPath = UserSettings.GetModelPath(); - string imagePath = UserSettings.GetImagePath(); + string modelImage = UserSettings.GetImagePath(); + const int maxTokens = 1024; - var prompt = (await File.ReadAllTextAsync("Assets/vicuna-llava-v16.txt")).Trim(); + var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n"; var parameters = new ModelParams(modelPath) { ContextSize = 4096, Seed = 1337, - GpuLayerCount = 5 }; using var model = LLamaWeights.LoadFromFile(parameters); using var context = model.CreateContext(parameters); @@ -26,26 +29,93 @@ namespace LLama.Examples.Examples var ex = new InteractiveExecutor(context, clipModel ); - ex.ImagePath = imagePath; - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 1024 and the context size is 4096. "); - Console.ForegroundColor = ConsoleColor.White; - - Console.Write(prompt); + Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize ); + Console.WriteLine("To send an image, enter its filename in curly braces, like this {c:/image.jpg}."); - var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List { "USER:" }, MaxTokens = 1024 }; + var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List { "\nUSER:" }, MaxTokens = maxTokens }; - while (true) + do { + + // Evaluate if we have images + // + var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); + var imageCount = imageMatches.Count(); + var hasImages = imageCount > 0; + byte[][] imageBytes = null; + + if (hasImages) + { + var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); + var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value); + + try + { + imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray(); + } + catch (IOException exception) + { + Console.ForegroundColor = ConsoleColor.Red; + Console.Write( + $"Could not load your {(imageCount == 1 ? "image" : "images")}:"); + Console.Write($"{exception.Message}"); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Please try again."); + break; + } + + + int index = 0; + foreach (var path in imagePathsWithCurlyBraces) + { + // First image replace to tag "); + else + prompt = prompt.Replace(path, ""); + } + + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message."); + Console.WriteLine(); + + foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes))) + { + consoleImage.MaxWidth = 50; + AnsiConsole.Write(consoleImage); + } + + Console.WriteLine(); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine($"The images were scaled down for the console only, the model gets full versions."); + Console.WriteLine($"Write /exit or press Ctrl+c to return to main menu."); + Console.WriteLine(); + + + // Initilize Images in executor + // + ex.ImagePaths = imagePaths.ToList(); + } + + Console.ForegroundColor = Color.White; await foreach (var text in ex.InferAsync(prompt, inferenceParams)) { Console.Write(text); } + Console.Write(" "); Console.ForegroundColor = ConsoleColor.Green; prompt = Console.ReadLine(); - Console.ForegroundColor = ConsoleColor.White; + Console.WriteLine(); + + // let the user finish with exit + // + if (prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase)) + break; + } + while(true); } } } diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index c6667c5a..69f74fff 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -19,6 +19,7 @@ + diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs index 51e613c3..ee4cf512 100644 --- a/LLama/Abstractions/ILLamaExecutor.cs +++ b/LLama/Abstractions/ILLamaExecutor.cs @@ -25,9 +25,9 @@ namespace LLama.Abstractions public LLavaWeights? ClipModel { get; } /// - /// Image filename and path (jpeg images). + /// List of images: Image filename and path (jpeg images). /// - public string? ImagePath { get; set; } + public List ImagePaths { get; set; } /// diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index b31d087a..52b38e18 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -71,7 +71,7 @@ namespace LLama { get { - return ClipModel != null && ImagePath != null; + return ClipModel != null; } } @@ -79,7 +79,7 @@ namespace LLama public LLavaWeights? ClipModel { get; } /// - public string? ImagePath { get; set; } + public List ImagePaths { get; set; } /// /// Current "mu" value for mirostat sampling @@ -95,6 +95,7 @@ namespace LLama /// protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) { + ImagePaths = new List(); _logger = logger; Context = context; _pastTokensCount = 0; diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 3041b2ef..21bb8dcc 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -24,7 +24,7 @@ namespace LLama // LLava private int _EmbedImagePosition = -1; - private SafeLlavaImageEmbedHandle _imageEmbedHandle = null; + private List _imageEmbedHandles = new List(); private bool _imageInPrompt = false; /// @@ -125,30 +125,7 @@ namespace LLama } else { - // If the prompt contains the tag extract this. - _imageInPrompt = text.Contains(""); - if (_imageInPrompt) - { - if (!string.IsNullOrEmpty(ImagePath)) - { - _imageEmbedHandle = SafeLlavaImageEmbedHandle.CreateFromFileName( ClipModel.NativeHandle, Context, ImagePath); - } - - int imageIndex = text.IndexOf(""); - // Tokenize segment 1 (before tag) - string preImagePrompt = text.Substring(0, imageIndex); - var segment1 = Context.Tokenize(preImagePrompt, true); - // Remember the position to add the image embeddings - _EmbedImagePosition = segment1.Length; - string postImagePrompt = text.Substring(imageIndex + 7); - var segment2 = Context.Tokenize(postImagePrompt, false); - _embed_inps.AddRange(segment1); - _embed_inps.AddRange(segment2); - } - else - { - _embed_inps = Context.Tokenize(text, true).ToList(); - } + PreprocessLlava(text, args, true ); } } else @@ -157,6 +134,7 @@ namespace LLama { text += "\n"; } + var line_inp = Context.Tokenize(text, false); _embed_inps.AddRange(line_inp); args.RemainedTokens -= line_inp.Length; @@ -165,6 +143,37 @@ namespace LLama return Task.CompletedTask; } + private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true ) + { + int usedTokens = 0; + // If the prompt contains the tag extract this. + _imageInPrompt = text.Contains(""); + if (_imageInPrompt) + { + foreach (var image in ImagePaths) + { + _imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromFileName( ClipModel.NativeHandle, Context, image ) ); + } + + int imageIndex = text.IndexOf(""); + // Tokenize segment 1 (before tag) + string preImagePrompt = text.Substring(0, imageIndex); + var segment1 = Context.Tokenize(preImagePrompt, addBos ); + // Remember the position to add the image embeddings + _EmbedImagePosition = segment1.Length; + string postImagePrompt = text.Substring(imageIndex + 7); + var segment2 = Context.Tokenize(postImagePrompt, false); + _embed_inps.AddRange(segment1); + _embed_inps.AddRange(segment2); + usedTokens += (segment1.Length + segment2.Length); + } + else + { + _embed_inps = Context.Tokenize(text, true).ToList(); + } + return Task.CompletedTask; + } + /// /// Return whether to break the generation. /// @@ -216,18 +225,19 @@ namespace LLama (DecodeResult, int) header, end, result; if (IsMultiModal && _EmbedImagePosition > 0) { - // Previous to Image + // Tokens previous to the images header = Context.NativeHandle.Decode(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount); if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1); - // Image - ClipModel.EvalImageEmbed(Context, _imageEmbedHandle, ref _pastTokensCount); + // Images + foreach( var image in _imageEmbedHandles ) + ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount); - // Post-image + // Post-image Tokens end = Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount); _EmbedImagePosition = -1; - + _imageEmbedHandles.Clear(); } else { diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index e702f47f..9d705af1 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -27,8 +27,8 @@ namespace LLama // LLava Section public bool IsMultiModal => false; public bool MultiModalProject { get; } - public LLavaWeights ClipModel { get; } - public string ImagePath { get; set; } + public LLavaWeights? ClipModel { get; } + public List ImagePaths { get; set; } /// /// The context used by the executor when running the inference. @@ -43,6 +43,7 @@ namespace LLama /// public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null) { + ImagePaths = new List(); _weights = weights; _params = @params; _logger = logger;