diff --git a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs index 507f041b..fac10ef1 100644 --- a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs +++ b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs @@ -3,6 +3,7 @@ using LLama.Batched; using LLama.Common; using Spectre.Console; using LLama.Abstractions; +using LLama.Native; namespace LLama.Examples.Examples { @@ -21,9 +22,6 @@ namespace LLama.Examples.Examples var parameters = new ModelParams(modelPath) { - ContextSize = 4096, - Seed = 1337, - GpuLayerCount = 10 }; using var model = LLamaWeights.LoadFromFile(parameters); using var context = model.CreateContext(parameters); @@ -69,6 +67,9 @@ namespace LLama.Examples.Examples break; } + // Each prompt with images we clear cache + // When the prompt contains images we clear KV_CACHE to restart conversation + ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 ); int index = 0; foreach (var path in imagePathsWithCurlyBraces) diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index a87a0f37..055a5f13 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -11,7 +11,7 @@ using System.Threading.Tasks; using LLama.Exceptions; using LLama.Extensions; using Microsoft.Extensions.Logging; -using System.Net.Http; + namespace LLama { @@ -136,20 +136,29 @@ namespace LLama text += "\n"; } - var line_inp = Context.Tokenize(text, false); - _embed_inps.AddRange(line_inp); - args.RemainedTokens -= line_inp.Length; + if (!this.IsMultiModal) + { + var line_inp = Context.Tokenize(text, false); + _embed_inps.AddRange(line_inp); + args.RemainedTokens -= line_inp.Length; + } + else + { + PreprocessLlava(text, args, false); + } } return Task.CompletedTask; } + /// <inheritdoc /> private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true ) { int usedTokens = 0; + // If the prompt contains the tag <image> extract this.
_imageInPrompt = text.Contains("<image>"); - if (_imageInPrompt && ClipModel != null) + if (_imageInPrompt && IsMultiModal ) { foreach (var image in Images) { @@ -170,7 +179,16 @@ namespace LLama } else { - _embed_inps = Context.Tokenize(text, true).ToList(); + if (addBos) + { + _embed_inps = Context.Tokenize(text, true).ToList(); + } + else + { + var line_inp = Context.Tokenize(text, false); + _embed_inps.AddRange(line_inp); + args.RemainedTokens -= line_inp.Length; + } } return Task.CompletedTask; } @@ -239,6 +257,7 @@ namespace LLama _EmbedImagePosition = -1; _imageEmbedHandles.Clear(); + Images.Clear(); } else {