Browse Source

Standardizing Image Data implementation

pull/653/head
Zoli Somogyi 1 year ago
parent
commit
e991e631f9
5 changed files with 85 additions and 26 deletions
  1. +5
    -1
      LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
  2. +47
    -6
      LLama/Abstractions/ILLamaExecutor.cs
  3. +9
    -7
      LLama/LLamaExecutorBase.cs
  4. +14
    -8
      LLama/LLamaInteractExecutor.cs
  5. +10
    -4
      LLama/LLamaStatelessExecutor.cs

+ 5
- 1
LLama.Examples/Examples/LlavaInteractiveModeExecute.cs View File

@@ -2,6 +2,7 @@
using LLama.Batched;
using LLama.Common;
using Spectre.Console;
using LLama.Abstractions;

namespace LLama.Examples.Examples
{
@@ -99,7 +100,10 @@ namespace LLama.Examples.Examples

// Initilize Images in executor
//
ex.ImagePaths = imagePaths.ToList();
foreach (var image in imagePaths)
{
ex.Images.Add(new ImageData(ImageData.DataType.ImagePath, image));
}
}

Console.ForegroundColor = Color.White;


+ 47
- 6
LLama/Abstractions/ILLamaExecutor.cs View File

@@ -22,14 +22,13 @@ namespace LLama.Abstractions
/// <summary>
/// Muti-Modal Projections / Clip Model weights
/// </summary>
public LLavaWeights? ClipModel { get; }
public LLavaWeights? ClipModel { get; }
/// <summary>
/// List of images: Image filename and path (jpeg images).
/// List of images: Image filen path, uri or image byte array. See ImageData.
/// </summary>
public List<string> ImagePaths { get; set; }
public List<ImageData> Images { get; }

/// <summary>
/// Asynchronously infers a response from the model.
/// </summary>
@@ -39,4 +38,46 @@ namespace LLama.Abstractions
/// <returns></returns>
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
}

/// <summary>
/// Holds image data
/// </summary>
public class ImageData
{
/// <summary>
/// constructor
/// </summary>
/// <param name="type"></param>
/// <param name="data"></param>
public ImageData(DataType type, object data) { Type = type; Data = data; }

/// <summary>
/// the possible types of image data
/// </summary>
public enum DataType
{
/// <summary>
/// file path
/// </summary>
ImagePath,
/// <summary>
/// byte array
/// </summary>
ImageBytes,
/// <summary>
/// uri
/// </summary>
ImageURL
}

/// <summary>
/// the type of this image data
/// </summary>
public DataType Type { get; set; }

/// <summary>
/// the image data (string, byte array or uri)
/// </summary>
public object? Data { get; set; }
}
}

+ 9
- 7
LLama/LLamaExecutorBase.cs View File

@@ -76,13 +76,10 @@ namespace LLama
}
/// <inheritdoc />
public LLavaWeights? ClipModel { get; }
/// <inheritdoc />
public List<string> ImagePaths { get; set; }
public LLavaWeights? ClipModel { get; }

/// <inheritdoc />
public List<byte[]> ImageBytes { get; set; }
public List<ImageData> Images { get; set; }

/// <summary>
/// Current "mu" value for mirostat sampling
@@ -98,8 +95,7 @@ namespace LLama
/// <param name="logger"></param>
protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
{
ImagePaths = new List<string>();
ImageBytes = new List<byte[]>();
Images = new List<ImageData>();
_logger = logger;
Context = context;
_pastTokensCount = 0;
@@ -109,6 +105,12 @@ namespace LLama
_decoder = new StreamingTokenDecoder(context);
}
/// <summary>
///
/// </summary>
/// <param name="context"></param>
/// <param name="lLavaWeights"></param>
/// <param name="logger"></param>
public StatefulExecutorBase(LLamaContext context, LLavaWeights lLavaWeights, ILogger? logger = null) :
this( context, logger )
{


+ 14
- 8
LLama/LLamaInteractExecutor.cs View File

@@ -148,16 +148,22 @@ namespace LLama
int usedTokens = 0;
// If the prompt contains the tag <image> extract this.
_imageInPrompt = text.Contains("<image>");
if (_imageInPrompt)
if (_imageInPrompt && ClipModel != null)
{
foreach (var image in ImagePaths)
foreach (var image in Images)
{
_imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromFileName( ClipModel.NativeHandle, Context, image ) );
}

foreach (var image in ImageBytes)
{
_imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromMemory(ClipModel.NativeHandle, Context, image));
if (image.Type == ImageData.DataType.ImagePath && image.Data != null)
{
_imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromFileName(ClipModel.NativeHandle, Context, image.Data.ToString()));
}
else if (image.Type == ImageData.DataType.ImageBytes && image.Data != null)
{
_imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromMemory(ClipModel.NativeHandle, Context, (byte[])image.Data));
}
else if (image.Type == ImageData.DataType.ImageURL && image.Data != null)
{
throw new NotImplementedException();
}
}

int imageIndex = text.IndexOf("<image>");


+ 10
- 4
LLama/LLamaStatelessExecutor.cs View File

@@ -26,10 +26,16 @@ namespace LLama
// LLava Section
public bool IsMultiModal => false;

/// <inheritdoc />
public bool MultiModalProject { get; }
public LLavaWeights? ClipModel { get; }
public List<string> ImagePaths { get; set; }

/// <inheritdoc />
public LLavaWeights? ClipModel { get; }

/// <inheritdoc />
public List<ImageData> Images { get; set; }

/// <summary>
/// The context used by the executor when running the inference.
/// </summary>
@@ -43,7 +49,7 @@ namespace LLama
/// <param name="logger"></param>
public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
ImagePaths = new List<string>();
Images = new List<ImageData>();
_weights = weights;
_params = @params;
_logger = logger;


Loading…
Cancel
Save