
feat: add get-embedding api to LLamaModel.

tags/v0.2.2
Yaohui Liu 2 years ago
parent commit 0958bbac2c
6 changed files with 100 additions and 8 deletions
  1. LLama.Examples/GetEmbeddings.cs (+22, -0)
  2. LLama.Examples/Program.cs (+9, -4)
  3. LLama/GptModel.cs (+3, -2)
  4. LLama/LLamaEmbedder.cs (+64, -0)
  5. LLama/Native/SafeLLamaContextHandle.cs (+1, -1)
  6. LLama/Native/SafeLLamaHandleBase.cs (+1, -1)

LLama.Examples/GetEmbeddings.cs (+22, -0)

@@ -0,0 +1,22 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace LLama.Examples
+{
+    public class GetEmbeddings
+    {
+        LLamaEmbedder _embedder;
+        public GetEmbeddings(string modelPath)
+        {
+            _embedder = new LLamaEmbedder(new LLamaParams(model: modelPath));
+        }
+
+        public void Run(string text)
+        {
+            Console.WriteLine(string.Join(", ", _embedder.GetEmbeddings(text)));
+        }
+    }
+}

LLama.Examples/Program.cs (+9, -4)

@@ -2,11 +2,11 @@
 using LLama.Examples;
 using LLama.Types;
 
-int choice = 3;
+int choice = 0;
 
 if(choice == 0)
 {
-    ChatSession chat = new(@"C:\Users\haipi\Source\repos\ggml-model-q4_0.bin", @"C:\Users\haipi\Source\repos\SciSharp\LLamaSharp\LLama.Examples\Assets\chat-with-bob.txt", new string[] { "User:" });
+    ChatSession chat = new(@"<Your model file path>", @"<Your prompt file path>", new string[] { "User:" });
     chat.Run();
 }
 else if(choice == 1)
@@ -22,6 +22,11 @@ else if(choice == 2)
 else if (choice == 3) // quantization
 {
     Quantize q = new Quantize();
-    q.Run(@"D:\development\llama\weights\LLaMA\7B\ggml-model-f16.bin",
-        @"D:\development\llama\weights\LLaMA\7B\ggml-model-q4_1.bin", "q4_1");
+    q.Run(@"<Your src model file path>",
+        @"<Your dst model file path>", "q4_1");
+}
+else if (choice == 4) // get embeddings
+{
+    GetEmbeddings em = new GetEmbeddings(@"<Your model file path>");
+    em.Run("Hello, what is python?");
 }

LLama/GptModel.cs (+3, -2)

@@ -42,6 +42,7 @@ namespace LLama
         bool _first_time_chat = true;
 
         public string Name { get; set; }
+        public SafeLLamaContextHandle NativeHandle => _ctx;
 
         public LLamaModel(string model_path, string model_name, bool echo_input = false, bool verbose = false, int seed = 0, int n_threads = -1, int n_predict = -1,
             int n_parts = -1, int n_ctx = 512, int n_batch = 512, int n_keep = 0,
@@ -70,7 +71,7 @@ namespace LLama
             _ctx = Utils.llama_init_from_gpt_params(ref _params);
 
             // Add a space in front of the first character to match OG llama tokenizer behavior
-            _params.prompt.Insert(0, " ");
+            _params.prompt = _params.prompt.Insert(0, " ");
             _session_tokens = new List<llama_token>();
 
             _path_session = @params.path_session;
@@ -223,7 +224,7 @@ namespace LLama
             _params.prompt = prompt;
             if (!_params.prompt.EndsWith(" "))
             {
-                _params.prompt = _params.prompt.Insert(0, " ");
+                _params.prompt = _params.prompt.Insert(0, " ");
             }
             _embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true);
             if (_embed_inp.Count > _n_ctx - 4)
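
Both Insert fixes in this file address the same pitfall: C# strings are immutable, so string.Insert returns a new string rather than modifying the receiver. The original calls discarded that result, and the leading space was never actually added. A minimal sketch of the difference:

    // C# strings are immutable: Insert returns a new string.
    string prompt = "Hello";
    prompt.Insert(0, " ");          // no-op in effect: the new string is discarded
    prompt = prompt.Insert(0, " "); // prompt is now " Hello"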


LLama/LLamaEmbedder.cs (+64, -0)

@@ -0,0 +1,64 @@
+using LLama.Native;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using LLama.Exceptions;
+
+namespace LLama
+{
+    public class LLamaEmbedder
+    {
+        SafeLLamaContextHandle _ctx;
+
+        /// <summary>
+        /// Warning: must ensure the original model has params.embedding = true;
+        /// </summary>
+        /// <param name="ctx"></param>
+        internal LLamaEmbedder(SafeLLamaContextHandle ctx)
+        {
+            _ctx = ctx;
+        }
+
+        public LLamaEmbedder(LLamaParams @params)
+        {
+            @params.embedding = true;
+            _ctx = Utils.llama_init_from_gpt_params(ref @params);
+        }
+
+        public unsafe float[] GetEmbeddings(string text, int n_thread = -1, bool add_bos = true)
+        {
+            if(n_thread == -1)
+            {
+                n_thread = Math.Max(Environment.ProcessorCount / 2, 1);
+            }
+            int n_past = 0;
+            if (add_bos)
+            {
+                text = text.Insert(0, " ");
+            }
+            var embed_inp = Utils.llama_tokenize(_ctx, text, add_bos);
+
+            // TODO(Rinne): deal with log of prompt
+
+            if (embed_inp.Count > 0)
+            {
+                var embed_inp_array = embed_inp.ToArray();
+                if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, n_thread) != 0)
+                {
+                    throw new RuntimeError("Failed to eval.");
+                }
+            }
+
+            int n_embed = NativeApi.llama_n_embd(_ctx);
+            var embeddings = NativeApi.llama_get_embeddings(_ctx);
+            if(embeddings == null)
+            {
+                return new float[0];
+            }
+            var span = new Span<float>(embeddings, n_embed);
+            float[] res = new float[n_embed];
+            span.CopyTo(res.AsSpan());
+            return res;
+        }
+    }
+}

LLama/Native/SafeLLamaContextHandle.cs (+1, -1)

@@ -5,7 +5,7 @@ using System.Text;
 
 namespace LLama.Native
 {
-    internal class SafeLLamaContextHandle: SafeLLamaHandleBase
+    public class SafeLLamaContextHandle: SafeLLamaHandleBase
     {
         protected SafeLLamaContextHandle()
         {


LLama/Native/SafeLLamaHandleBase.cs (+1, -1)

@@ -5,7 +5,7 @@ using System.Text;
 
 namespace LLama.Native
 {
-    internal abstract class SafeLLamaHandleBase: SafeHandle
+    public abstract class SafeLLamaHandleBase: SafeHandle
     {
         private protected SafeLLamaHandleBase()
             : base(IntPtr.Zero, ownsHandle: true)
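
These two visibility changes support the new NativeHandle property on LLamaModel: C# rejects a public property whose type is internal (inconsistent accessibility), so SafeLLamaContextHandle and its base class both become public. A minimal sketch, assuming only the APIs in this diff, of what code outside the LLama assembly can now write (the model path and name are placeholders):

    using LLama;
    using LLama.Native;

    // Hypothetical caller code outside the LLama assembly.
    var model = new LLamaModel(@"<Your model file path>", "my-model");
    SafeLLamaContextHandle ctx = model.NativeHandle; // compiles only because the handle types are now public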

