using LLama.Native;
using System;
using LLama.Exceptions;
using LLama.Abstractions;
using Microsoft.Extensions.Logging;
namespace LLama
{
    /// <summary>
    /// The embedder for LLama, which supports getting embeddings from text.
    /// </summary>
    public sealed class LLamaEmbedder
        : IDisposable
    {
        /// <summary>
        /// Dimension of embedding vectors, as reported by <see cref="Context"/>.
        /// </summary>
        public int EmbeddingSize => Context.EmbeddingSize;

        /// <summary>
        /// The LLama context used to evaluate text. It is created by the constructor
        /// and disposed by <see cref="Dispose"/>.
        /// </summary>
        public LLamaContext Context { get; }

        /// <summary>
        /// Create a new embedder, using the given <see cref="LLamaWeights"/>.
        /// </summary>
        /// <param name="weights">The model weights used to create <see cref="Context"/>.</param>
        /// <param name="params">
        /// Context parameters. Note: <c>EmbeddingMode</c> is set to <c>true</c> on this
        /// object itself, so the caller's instance is modified.
        /// </param>
        /// <param name="logger">Optional logger, passed on to the created context.</param>
        public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
        {
            // Presumably required so the native context computes embeddings on evaluation.
            @params.EmbeddingMode = true;
            Context = weights.CreateContext(@params, logger);
        }

        /// <summary>
        /// Get the embeddings of the text.
        /// </summary>
        /// <param name="text">The text to embed.</param>
        /// <param name="threads">Unused.</param>
        /// <param name="addBos">Add bos to the text.</param>
        /// <param name="encoding">Unused.</param>
        /// <returns>
        /// The embedding vector, or an empty array if no embedding is available.
        /// </returns>
        [Obsolete("'threads' and 'encoding' parameters are no longer used")]
        // ReSharper disable once MethodOverloadWithOptionalParameter
        public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
        {
            return GetEmbeddings(text, addBos);
        }

        /// <summary>
        /// Get the embeddings of the text, adding bos to the text.
        /// </summary>
        /// <param name="text">The text to embed.</param>
        /// <returns>
        /// The embedding vector, or an empty array if no embedding is available.
        /// </returns>
        public float[] GetEmbeddings(string text)
        {
            return GetEmbeddings(text, true);
        }

        /// <summary>
        /// Get the embeddings of the text.
        /// </summary>
        /// <param name="text">The text to embed.</param>
        /// <param name="addBos">Add bos to the text.</param>
        /// <returns>
        /// The embedding vector. Returns an empty array if tokenization produces no tokens
        /// or if the native embedding buffer is unavailable.
        /// </returns>
        public float[] GetEmbeddings(string text, bool addBos)
        {
            var tokens = Context.Tokenize(text, addBos);

            // TODO(Rinne): deal with log of prompt

            // With no tokens nothing is evaluated. Reading the native embedding buffer
            // anyway would return whatever a previous call left behind (the embedding of
            // some other text), so report "no embedding" instead.
            if (tokens.Length == 0)
            {
                return Array.Empty<float>();
            }

            // NOTE(review): 0 is presumably the number of past tokens, i.e. evaluation
            // restarts from the beginning of the context on every call — confirm.
            Context.Eval(tokens, 0);

            var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
            if (embeddings == null)
            {
                return Array.Empty<float>();
            }

            // Materialize into a new managed array so later evaluations cannot
            // change the returned values.
            return embeddings.ToArray();
        }

        /// <summary>
        /// Dispose the underlying <see cref="Context"/>.
        /// </summary>
        public void Dispose()
        {
            Context.Dispose();
        }
    }
}