
Merge branch 'deps/sk-rc3' of https://github.com/xbotter/LLamaSharp into deps/sk-rc3

tags/0.9.1
xbotter, 1 year ago
parent commit 8766fb1b03
No known key found for this signature in database. GPG Key ID: A3F32F44E9F160E1
32 changed files with 1378 additions and 419 deletions
  1. +1 -1  .github/workflows/compile.yml
  2. +24 -0  LLama.Examples/Assets/chat-with-bob.json
  3. +24 -0  LLama.Examples/Assets/chat-with-kunkun-chinese.json
  4. +102 -49  LLama.Examples/Examples/ChatChineseGB2312.cs
  5. +48 -31  LLama.Examples/Examples/ChatSessionStripRoleName.cs
  6. +98 -0  LLama.Examples/Examples/ChatSessionWithHistory.cs
  7. +45 -31  LLama.Examples/Examples/ChatSessionWithRoleName.cs
  8. +11 -2  LLama.Examples/Examples/LoadAndSaveSession.cs
  9. +2 -1  LLama.Examples/Examples/Runner.cs
  10. +7 -1  LLama.Examples/LLama.Examples.csproj
  11. +1 -1  LLama.KernelMemory/LLamaSharp.KernelMemory.csproj
  12. +1 -2  LLama.Unittest/GrammarParserTest.cs
  13. +4 -4  LLama.Unittest/LLama.Unittest.csproj
  14. +8 -2  LLama.Unittest/ModelsParamsTests.cs
  15. +5 -1  LLama.Unittest/StatelessExecutorTest.cs
  16. +8 -2  LLama.Web/Common/InferenceOptions.cs
  17. +19 -16  LLama.WebAPI/Services/StatefulChatService.cs
  18. +6 -0  LLama/Abstractions/IInferenceParams.cs
  19. +24 -0  LLama/Abstractions/IModelParams.cs
  20. +450 -200  LLama/ChatSession.cs
  21. +34 -6  LLama/Common/ChatHistory.cs
  22. +4 -0  LLama/Common/InferenceParams.cs
  23. +13 -35  LLama/Common/ModelParams.cs
  24. +12 -0  LLama/LLamaContext.cs
  25. +17 -9  LLama/LLamaInstructExecutor.cs
  26. +18 -10  LLama/LLamaInteractExecutor.cs
  27. +1 -0  LLama/LLamaSharp.csproj
  28. +19 -10  LLama/LLamaStatelessExecutor.cs
  29. +34 -5  LLama/Native/LLamaTokenDataArray.cs
  30. +128 -0  LLama/Sampling/BaseSamplingPipeline.cs
  31. +149 -0  LLama/Sampling/DefaultSamplingPipeline.cs
  32. +61 -0  LLama/Sampling/ISamplingPipeline.cs

+1 -1  .github/workflows/compile.yml

@@ -140,7 +140,7 @@ jobs:
- build: 'arm64'
defines: '-DCMAKE_OSX_ARCHITECTURES=arm64'
- build: 'x64'
defines: '-DCMAKE_OSX_ARCHITECTURES=x86_64 -DLLAMA_METAL=OFF'
defines: '-DCMAKE_OSX_ARCHITECTURES=x86_64 -DLLAMA_METAL=OFF -DLLAMA_AVX=ON -DLLAMA_AVX2=ON'
runs-on: macos-latest
steps:
- uses: actions/checkout@v3


+24 -0  LLama.Examples/Assets/chat-with-bob.json

@@ -0,0 +1,24 @@
{
"messages": [
{
"author_role": "System",
"content": "Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision."
},
{
"author_role": "User",
"content": "Hello, Bob."
},
{
"author_role": "Assistant",
"content": "Hello. How may I help you today?"
},
{
"author_role": "User",
"content": "Please tell me the largest city in Europe."
},
{
"author_role": "Assistant",
"content": "Sure. The largest city in Europe is Istanbul, Turkey."
}
]
}

+24 -0  LLama.Examples/Assets/chat-with-kunkun-chinese.json

@@ -0,0 +1,24 @@
{
"messages": [
{
"author_role": "System",
"content": "下面是一段你和用户的对话,你叫坤坤,是一个在各方面都拥有丰富经验的助理,你非常乐于回答用户的问题和帮助用户。"
},
{
"author_role": "User",
"content": "你好,坤坤。"
},
{
"author_role": "Assistant",
"content": "你好,有什么我能帮助你的吗?"
},
{
"author_role": "User",
"content": "中国的首都是哪座城市?"
},
{
"author_role": "Assistant",
"content": "中国的首都是北京市。"
}
]
}

+102 -49  LLama.Examples/Examples/ChatChineseGB2312.cs

@@ -1,69 +1,122 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Text;
using LLama.Common;

namespace LLama.Examples.Examples
namespace LLama.Examples.Examples;

public class ChatChineseGB2312
{
public class ChatChineseGB2312
private static string ConvertEncoding(string input, Encoding original, Encoding target)
{
byte[] bytes = original.GetBytes(input);
var convertedBytes = Encoding.Convert(original, target, bytes);
return target.GetString(convertedBytes);
}

public static async Task Run()
{
private static string ConvertFromEncodingToAnother(string input, Encoding original, Encoding target)
// Register provider for GB2312 encoding
Encoding.RegisterProvider(CodePagesEncodingProvider.Instance);

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("This example shows how to use Chinese with gb2312 encoding, which is common in windows. It's recommended" +
" to use https://huggingface.co/hfl/chinese-alpaca-2-7b-gguf/blob/main/ggml-model-q5_0.gguf, which has been verified by LLamaSharp developers.");
Console.ForegroundColor = ConsoleColor.White;

Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath)
{
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5,
Encoding = Encoding.UTF8
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

ChatSession session;
if (Directory.Exists("Assets/chat-with-kunkun-chinese"))
{
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Loading session from disk.");
Console.ForegroundColor = ConsoleColor.White;

session = new ChatSession(executor);
session.LoadSession("Assets/chat-with-kunkun-chinese");
}
else
{
byte[] bytes = original.GetBytes(input);
var convertedBytes = Encoding.Convert(original, target, bytes);
return target.GetString(convertedBytes);
var chatHistoryJson = File.ReadAllText("Assets/chat-with-kunkun-chinese.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

session = new ChatSession(executor, chatHistory);
}

public static async Task Run()
session
.WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤"));

InferenceParams inferenceParams = new InferenceParams()
{
Encoding.RegisterProvider(CodePagesEncodingProvider.Instance); // Register gb2312 encoding
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
var prompt = File.ReadAllText("Assets/chat-with-kunkun-chinese.txt", encoding: Encoding.GetEncoding("gb2312")).Trim();
prompt = ConvertFromEncodingToAnother(prompt, Encoding.GetEncoding("gb2312"), Encoding.UTF8);
Temperature = 0.9f,
AntiPrompts = new List<string> { "用户:" }
};

var parameters = new ModelParams(modelPath)
{
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 20,
Encoding = Encoding.UTF8
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");

var session = new ChatSession(executor).WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户"));
// show the prompt
Console.ForegroundColor = ConsoleColor.White;
Console.Write("用户:");
Console.ForegroundColor = ConsoleColor.Green;
string userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("This example shows how to use Chinese with gb2312 encoding, which is common in windows. It's recommended" +
" to use https://huggingface.co/hfl/chinese-alpaca-2-7b-gguf/blob/main/ggml-model-q5_0.gguf, which has been verified by LLamaSharp developers.");
Console.ForegroundColor = ConsoleColor.White;
while (userInput != "exit")
{
// Convert the encoding from gb2312 to utf8 for the language model
// and later saving to the history json file.
userInput = ConvertEncoding(userInput, Encoding.GetEncoding("gb2312"), Encoding.UTF8);

// show the prompt
Console.Write(prompt);
while (true)
if (userInput == "save")
{
await foreach (var text in session.ChatAsync(prompt, new InferenceParams()
session.SaveSession("Assets/chat-with-kunkun-chinese");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session saved.");
}
else if (userInput == "regenerate")
{
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Regenerating last response ...");

await foreach (
var text
in session.RegenerateAssistantMessageAsync(
inferenceParams))
{
Temperature = 0.3f,
TopK = 5,
TopP = 0.85f,
AntiPrompts = new List<string> { "用户:" },
MaxTokens = 2048,
RepeatPenalty = 1.05f
}))
Console.ForegroundColor = ConsoleColor.White;

// Convert the encoding from utf8 to gb2312 for the console output.
Console.Write(ConvertEncoding(text, Encoding.UTF8, Encoding.GetEncoding("gb2312")));
}
}
else
{
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
inferenceParams))
{
//Console.Write(text);
Console.Write(ConvertFromEncodingToAnother(text, Encoding.UTF8, Encoding.GetEncoding("gb2312")));
Console.ForegroundColor = ConsoleColor.White;
Console.Write(text);
}

Console.ForegroundColor = ConsoleColor.Green;
prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;
}

Console.ForegroundColor = ConsoleColor.Green;
userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.White;
}
}
}

+48 -31  LLama.Examples/Examples/ChatSessionStripRoleName.cs

@@ -1,44 +1,61 @@
using LLama.Common;

namespace LLama.Examples.Examples
namespace LLama.Examples.Examples;

public class ChatSessionStripRoleName
{
public class ChatSessionStripRoleName
public static async Task Run()
{
public static async Task Run()
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath)
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var parameters = new ModelParams(modelPath)
{
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8));

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started. The role names won't be printed.");
Console.ForegroundColor = ConsoleColor.White;
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

// show the prompt
Console.Write(prompt);
while (true)
{
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
{
Console.Write(text);
}
ChatSession session = new(executor, chatHistory);
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
new string[] { "User:", "Assistant:" },
redundancyLength: 8));

InferenceParams inferenceParams = new InferenceParams()
{
Temperature = 0.9f,
AntiPrompts = new List<string> { "User:" }
};

Console.ForegroundColor = ConsoleColor.Green;
prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");

// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
string userInput = Console.ReadLine() ?? "";

while (userInput != "exit")
{
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write(text);
}

Console.ForegroundColor = ConsoleColor.Green;
userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.White;
}
}
}

+98 -0  LLama.Examples/Examples/ChatSessionWithHistory.cs

@@ -0,0 +1,98 @@
using LLama.Common;

namespace LLama.Examples.Examples;

public class ChatSessionWithHistory
{
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath)
{
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

ChatSession session;
if (Directory.Exists("Assets/chat-with-bob"))
{
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Loading session from disk.");
Console.ForegroundColor = ConsoleColor.White;

session = new ChatSession(executor);
session.LoadSession("Assets/chat-with-bob");
}
else
{
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

session = new ChatSession(executor, chatHistory);
}

session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
new string[] { "User:", "Assistant:" },
redundancyLength: 8));

InferenceParams inferenceParams = new InferenceParams()
{
Temperature = 0.9f,
AntiPrompts = new List<string> { "User:" }
};

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");

// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
string userInput = Console.ReadLine() ?? "";

while (userInput != "exit")
{
if (userInput == "save")
{
session.SaveSession("Assets/chat-with-bob");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session saved.");
}
else if (userInput == "regenerate")
{
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Regenerating last response ...");

await foreach (
var text
in session.RegenerateAssistantMessageAsync(
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write(text);
}
}
else
{
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write(text);
}
}

Console.ForegroundColor = ConsoleColor.Green;
userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.White;
}
}
}

+45 -31  LLama.Examples/Examples/ChatSessionWithRoleName.cs

@@ -1,44 +1,58 @@
using LLama.Common;

namespace LLama.Examples.Examples
namespace LLama.Examples.Examples;

public class ChatSessionWithRoleName
{
public class ChatSessionWithRoleName
public static async Task Run()
{
public static async Task Run()
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath)
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var parameters = new ModelParams(modelPath)
{
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var session = new ChatSession(executor);

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result.");
Console.ForegroundColor = ConsoleColor.White;
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

// show the prompt
Console.Write(prompt);
while (true)
{
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
{
Console.Write(text);
}
ChatSession session = new(executor, chatHistory);

InferenceParams inferenceParams = new InferenceParams()
{
Temperature = 0.9f,
AntiPrompts = new List<string> { "User:" }
};

Console.ForegroundColor = ConsoleColor.Green;
prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");

// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
string userInput = Console.ReadLine() ?? "";

while (userInput != "exit")
{
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write(text);
}

Console.ForegroundColor = ConsoleColor.Green;
userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.White;
}
}
}

+11 -2  LLama.Examples/Examples/LoadAndSaveSession.cs

@@ -1,4 +1,5 @@
using LLama.Common;
using DocumentFormat.OpenXml.Bibliography;
using LLama.Common;

namespace LLama.Examples.Examples
{
@@ -30,7 +31,15 @@ namespace LLama.Examples.Examples
Console.Write(prompt);
while (true)
{
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, prompt),
new InferenceParams()
{
Temperature = 0.6f,
AntiPrompts = new List<string> { "User:" }
}))
{
Console.Write(text);
}


+2 -1  LLama.Examples/Examples/Runner.cs

@@ -6,8 +6,10 @@ public class Runner
{
private static readonly Dictionary<string, Func<Task>> Examples = new()
{
{ "Run a chat session with history.", ChatSessionWithHistory.Run },
{ "Run a chat session without stripping the role names.", ChatSessionWithRoleName.Run },
{ "Run a chat session with the role names stripped.", ChatSessionStripRoleName.Run },
{ "Run a chat session in Chinese GB2312 encoding", ChatChineseGB2312.Run },
{ "Interactive mode chat by using executor.", InteractiveModeExecute.Run },
{ "Instruct mode chat by using executor.", InstructModeExecute.Run },
{ "Stateless mode chat by using executor.", StatelessModeExecute.Run },
@@ -23,7 +25,6 @@ public class Runner
{ "Coding Assistant.", CodingAssistant.Run },
{ "Batch Decoding.", BatchedDecoding.Run },
{ "SK Kernel Memory.", KernelMemory.Run },
{ "Chinese gb2312 chat", ChatChineseGB2312.Run },
{ "Exit", async () => Environment.Exit(0) }
};



+7 -1  LLama.Examples/LLama.Examples.csproj

@@ -2,7 +2,7 @@
<Import Project="..\LLama\LLamaSharp.Runtime.targets" />
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net6.0</TargetFramework>
<TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<Platforms>AnyCPU;x64</Platforms>
@@ -27,6 +27,12 @@
</ItemGroup>

<ItemGroup>
<None Update="Assets\chat-with-bob.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\chat-with-kunkun-chinese.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\chat-with-bob.txt">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>


+1 -1  LLama.KernelMemory/LLamaSharp.KernelMemory.csproj

@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>net6.0;net7.0</TargetFrameworks>
<TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<Version>0.8.0</Version>


+1 -2  LLama.Unittest/GrammarParserTest.cs

@@ -1,5 +1,4 @@
using System.Text;
using LLama.Exceptions;
using LLama.Exceptions;
using LLama.Native;
using LLama.Grammars;



+4 -4  LLama.Unittest/LLama.Unittest.csproj

@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\LLama\LLamaSharp.Runtime.targets" />
<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<TargetFramework>net8.0</TargetFramework>
<RootNamespace>LLama.Unittest</RootNamespace>
<ImplicitUsings>enable</ImplicitUsings>
<Platforms>AnyCPU;x64</Platforms>
@@ -15,8 +15,8 @@
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.8.0" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
<PackageReference Include="xunit" Version="2.6.2" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.4">
<PackageReference Include="xunit" Version="2.6.3" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.5">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>


+8 -2  LLama.Unittest/ModelsParamsTests.cs

@@ -1,4 +1,5 @@
using LLama.Common;
using System.Text.Json;

namespace LLama.Unittest
{
@@ -16,14 +17,19 @@ namespace LLama.Unittest
TensorSplits = { [0] = 3 }
};

var json = System.Text.Json.JsonSerializer.Serialize(expected);
var actual = System.Text.Json.JsonSerializer.Deserialize<ModelParams>(json)!;
var json = JsonSerializer.Serialize(expected);
var actual = JsonSerializer.Deserialize<ModelParams>(json)!;

// Cannot compare splits with default equality, check they are sequence equal and then set to null
Assert.Equal((IEnumerable<float>)expected.TensorSplits, expected.TensorSplits);
actual.TensorSplits = null!;
expected.TensorSplits = null!;

// Check encoding is the same
var b1 = expected.Encoding.GetBytes("Hello");
var b2 = actual.Encoding.GetBytes("Hello");
Assert.True(b1.SequenceEqual(b2));

Assert.Equal(expected, actual);
}



+5 -1  LLama.Unittest/StatelessExecutorTest.cs

@@ -1,5 +1,6 @@
using System.Diagnostics;
using LLama.Common;
using LLama.Sampling;
using Xunit.Abstractions;

namespace LLama.Unittest
@@ -30,10 +31,13 @@ namespace LLama.Unittest
[Fact]
public async Task Stateless()
{
// Create a custom pipeline that mimics the default pipeline
var pipeline = new DefaultSamplingPipeline();

var executor = new StatelessExecutor(_weights, _params);

const string question = "Question. what is a cat?\nAnswer: ";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

var timer = new Stopwatch();
timer.Start();


+8 -2  LLama.Web/Common/InferenceOptions.cs

@@ -1,6 +1,9 @@
using LLama.Common;
#nullable enable

using LLama.Common;
using LLama.Abstractions;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Web.Common
{
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
/// <summary>
/// A grammar to constrain possible tokens
/// </summary>
public SafeLLamaGrammarHandle Grammar { get; set; } = null;
public SafeLLamaGrammarHandle? Grammar { get; set; }

/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }
}
}

+19 -16  LLama.WebAPI/Services/StatefulChatService.cs

@@ -11,8 +11,7 @@ public class StatefulChatService : IDisposable
private readonly LLamaContext _context;
private bool _continue = false;

private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n"
+ "User: ";
private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.";

public StatefulChatService(IConfiguration configuration)
{
@@ -25,7 +24,9 @@ public class StatefulChatService : IDisposable
using var weights = LLamaWeights.LoadFromFile(@params);

_context = new LLamaContext(weights, @params);

_session = new ChatSession(new InteractiveExecutor(_context));
_session.History.AddMessage(Common.AuthorRole.System, SystemPrompt);
}

public void Dispose()
@@ -35,10 +36,8 @@ public class StatefulChatService : IDisposable

public async Task<string> Send(SendMessageInput input)
{
var userInput = input.Text;
if (!_continue)
{
userInput = SystemPrompt + userInput;
Console.Write(SystemPrompt);
_continue = true;
}
@@ -47,11 +46,14 @@ public class StatefulChatService : IDisposable
Console.Write(input.Text);

Console.ForegroundColor = ConsoleColor.White;
var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
{
RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" },
});
var outputs = _session.ChatAsync(
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
new Common.InferenceParams()
{
RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" },
});

var result = "";
await foreach (var output in outputs)
{
@@ -64,10 +66,8 @@ public class StatefulChatService : IDisposable

public async IAsyncEnumerable<string> SendStream(SendMessageInput input)
{
var userInput = input.Text;
if (!_continue)
{
userInput = SystemPrompt + userInput;
Console.Write(SystemPrompt);
_continue = true;
}
@@ -76,11 +76,14 @@ public class StatefulChatService : IDisposable
Console.Write(input.Text);

Console.ForegroundColor = ConsoleColor.White;
var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
{
RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" },
});
var outputs = _session.ChatAsync(
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text)
, new Common.InferenceParams()
{
RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" },
});

await foreach (var output in outputs)
{
Console.Write(output);


+6 -0  LLama/Abstractions/IInferenceParams.cs

@@ -1,6 +1,7 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Abstractions
{
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
/// Grammar to constrain possible tokens
/// </summary>
SafeLLamaGrammarHandle? Grammar { get; set; }

/// <summary>
/// Set a custom sampling pipeline to use. <b>If this is set All other sampling parameters are ignored!</b>
/// </summary>
ISamplingPipeline? SamplingPipeline { get; set; }
}
}
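
Note: a minimal usage sketch for the new SamplingPipeline property (the setup mirrors LLama.Unittest/StatelessExecutorTest.cs below; MaxTokens and the anti-prompt are arbitrary example values). Per the doc comment, setting a pipeline makes the executors ignore the classic sampling parameters.

    using System.Collections.Generic;
    using LLama.Common;
    using LLama.Sampling;

    // A pipeline replaces Temperature/TopK/TopP etc. for token selection.
    var inferenceParams = new InferenceParams
    {
        MaxTokens = 64,                                  // example value
        AntiPrompts = new List<string> { "User:" },      // example value
        SamplingPipeline = new DefaultSamplingPipeline()
    };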

+24 -0  LLama/Abstractions/IModelParams.cs

@@ -3,6 +3,9 @@ using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Common;
using LLama.Native;

namespace LLama.Abstractions
@@ -105,6 +108,7 @@ namespace LLama.Abstractions
/// <summary>
/// A fixed size array to set the tensor splits across multiple GPUs
/// </summary>
[JsonConverter(typeof(TensorSplitsCollectionConverter))]
public sealed class TensorSplitsCollection
: IEnumerable<float>
{
@@ -174,4 +178,24 @@ namespace LLama.Abstractions
}
#endregion
}

/// <summary>
/// A JSON converter for <see cref="TensorSplitsCollection"/>
/// </summary>
public class TensorSplitsCollectionConverter
: JsonConverter<TensorSplitsCollection>
{
/// <inheritdoc/>
public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
return new TensorSplitsCollection(arr);
}

/// <inheritdoc/>
public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
{
JsonSerializer.Serialize(writer, value.Splits, options);
}
}
}
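
Note: a short sketch of the round trip enabled by TensorSplitsCollectionConverter (the model path is hypothetical; the serialization calls mirror LLama.Unittest/ModelsParamsTests.cs). TensorSplits is written as a plain JSON float array.

    using System.Text.Json;
    using LLama.Common;

    var expected = new ModelParams("hypothetical-model.gguf") { TensorSplits = { [0] = 3 } };
    var json = JsonSerializer.Serialize(expected);        // ... "TensorSplits":[3,0,0, ...] ...
    var actual = JsonSerializer.Deserialize<ModelParams>(json);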

+450 -200  LLama/ChatSession.cs

@@ -1,246 +1,496 @@
using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
using static LLama.InteractiveExecutor;

namespace LLama
namespace LLama;

/// <summary>
/// The main chat session class.
/// </summary>
public class ChatSession
{
private const string _modelStateFilename = "ModelState.st";
private const string _executorStateFilename = "ExecutorState.json";
private const string _hsitoryFilename = "ChatHistory.json";

/// <summary>
/// The main chat session class.
/// </summary>
public class ChatSession
{
private readonly ILLamaExecutor _executor;
private readonly ChatHistory _history;

private const string _executorStateFilename = "ExecutorState.json";
private const string _modelStateFilename = "ModelState.st";

/// <summary>
/// The executor for this session.
/// </summary>
public ILLamaExecutor Executor => _executor;
/// <summary>
/// The chat history for this session.
/// </summary>
public ChatHistory History => _history;
/// <summary>
/// The history transform used in this session.
/// </summary>
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
/// <summary>
/// The input transform pipeline used in this session.
/// </summary>
public List<ITextTransform> InputTransformPipeline { get; set; } = new();
/// <summary>
/// The output transform used in this session.
/// </summary>
public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

/// <summary>
///
/// </summary>
/// <param name="executor">The executor for this session</param>
public ChatSession(ILLamaExecutor executor)
{
_executor = executor;
_history = new ChatHistory();
}

/// <summary>
/// Use a custom history transform.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession WithHistoryTransform(IHistoryTransform transform)
{
HistoryTransform = transform;
return this;
}

/// <summary>
/// Add a text transform to the input transform pipeline.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession AddInputTransform(ITextTransform transform)
{
InputTransformPipeline.Add(transform);
return this;
}

/// <summary>
/// Use a custom output transform.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession WithOutputTransform(ITextStreamTransform transform)
{
OutputTransform = transform;
return this;
}

/// <summary>
///
/// </summary>
/// <param name="path">The directory name to save the session. If the directory does not exist, a new directory will be created.</param>
public virtual void SaveSession(string path)
{
if (!Directory.Exists(path))
{
Directory.CreateDirectory(path);
}
_executor.Context.SaveState(Path.Combine(path, _modelStateFilename));
if (Executor is StatelessExecutor)
{
/// The executor for this session.
/// </summary>
public ILLamaExecutor Executor { get; private set; }

}
else if (Executor is StatefulExecutorBase statefulExecutor)
{
statefulExecutor.SaveState(Path.Combine(path, _executorStateFilename));
}
else
{
throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
}
/// <summary>
/// The chat history for this session.
/// </summary>
public ChatHistory History { get; private set; } = new();

/// <summary>
/// The history transform used in this session.
/// </summary>
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

/// <summary>
/// The input transform pipeline used in this session.
/// </summary>
public List<ITextTransform> InputTransformPipeline { get; set; } = new();

/// <summary>
/// The output transform used in this session.
/// </summary>
public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

/// <summary>
/// Create a new chat session.
/// </summary>
/// <param name="executor">The executor for this session</param>
public ChatSession(ILLamaExecutor executor)
{
// Check if executor has StatefulExecutorBase as base class
if (executor is not StatefulExecutorBase)
{
throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
}

/// <summary>
///
/// </summary>
/// <param name="path">The directory name to load the session.</param>
public virtual void LoadSession(string path)
Executor = executor;
}

/// <summary>
/// Create a new chat session with a custom history.
/// </summary>
/// <param name="executor"></param>
/// <param name="history"></param>
public ChatSession(ILLamaExecutor executor, ChatHistory history)
: this(executor)
{
History = history;
}

/// <summary>
/// Use a custom history transform.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession WithHistoryTransform(IHistoryTransform transform)
{
HistoryTransform = transform;
return this;
}

/// <summary>
/// Add a text transform to the input transform pipeline.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession AddInputTransform(ITextTransform transform)
{
InputTransformPipeline.Add(transform);
return this;
}

/// <summary>
/// Use a custom output transform.
/// </summary>
/// <param name="transform"></param>
/// <returns></returns>
public ChatSession WithOutputTransform(ITextStreamTransform transform)
{
OutputTransform = transform;
return this;
}

/// <summary>
/// Save a session from a directory.
/// </summary>
/// <param name="path"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public void SaveSession(string path)
{
if (string.IsNullOrWhiteSpace(path))
{
if (!Directory.Exists(path))
{
throw new FileNotFoundException($"Directory {path} does not exist.");
}
_executor.Context.LoadState(Path.Combine(path, _modelStateFilename));
if (Executor is StatelessExecutor)
{
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
}

}
else if (Executor is StatefulExecutorBase statefulExecutor)
{
statefulExecutor.LoadState(Path.Combine(path, _executorStateFilename));
}
else
{
throw new System.NotImplementedException("You're using a customized executor. Please inherit ChatSession and rewrite the method.");
}
if (Directory.Exists(path))
{
Directory.Delete(path, recursive: true);
}

/// <summary>
/// Generates a response for a given user prompt and manages history state for the user.
/// This will always pass the whole history to the model. Don't pass a whole history
/// to this method as the user prompt will be appended to the history of the current session.
/// If more control is needed, use the other overload of this method that accepts a ChatHistory object.
/// </summary>
/// <param name="prompt"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns>Returns generated text of the assistant message.</returns>
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
Directory.CreateDirectory(path);

string modelStateFilePath = Path.Combine(path, _modelStateFilename);
Executor.Context.SaveState(modelStateFilePath);

string executorStateFilepath = Path.Combine(path, _executorStateFilename);
((StatefulExecutorBase)Executor).SaveState(executorStateFilepath);

string historyFilepath = Path.Combine(path, _hsitoryFilename);
File.WriteAllText(historyFilepath, History.ToJson());
}

/// <summary>
/// Load a session from a directory.
/// </summary>
/// <param name="path"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public void LoadSession(string path)
{
if (string.IsNullOrWhiteSpace(path))
{
foreach (var inputTransform in InputTransformPipeline)
prompt = inputTransform.Transform(prompt);
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
}

if (!Directory.Exists(path))
{
throw new ArgumentException("Directory does not exist", nameof(path));
}

string modelStateFilePath = Path.Combine(path, _modelStateFilename);
Executor.Context.LoadState(modelStateFilePath);

// TODO: need to be refactored.
if (_executor is InteractiveExecutor executor && ((InteractiveExecutorState)executor.GetStateData()).IsPromptRun)
string executorStateFilepath = Path.Combine(path, _executorStateFilename);
((StatefulExecutorBase)Executor).LoadState(executorStateFilepath);

string historyFilepath = Path.Combine(path, _hsitoryFilename);
string historyJson = File.ReadAllText(historyFilepath);
History = ChatHistory.FromJson(historyJson)
?? throw new ArgumentException("History file is invalid", nameof(path));
}

/// <summary>
/// Add a message to the chat history.
/// </summary>
/// <param name="message"></param>
/// <returns></returns>
public ChatSession AddMessage(ChatHistory.Message message)
{
// If current message is a system message, only allow the history to be empty
if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0)
{
throw new ArgumentException("Cannot add a system message after another message", nameof(message));
}

// If current message is a user message, only allow the history to be empty,
// or the previous message to be a system message or assistant message.
if (message.AuthorRole == AuthorRole.User)
{
ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User)
{
History.Messages.Add(new ChatHistory.Message(AuthorRole.System, prompt));
var converted_prompt = HistoryTransform.HistoryToText(History);
// Avoid missing anti-prompt.
if (!prompt.EndsWith("\n") && !prompt.EndsWith("\r\n"))
{
prompt = converted_prompt.Trim();
}
else
{
prompt = converted_prompt;
}
throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
}
else
}

// If the current message is an assistant message,
// the previous message must be a user message.
if (message.AuthorRole == AuthorRole.Assistant)
{
ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
if (lastMessage is null
|| lastMessage.AuthorRole != AuthorRole.User)
{
History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
throw new ArgumentException("Assistant message must be preceeded with a user message", nameof(message));
}
}

History.AddMessage(message.AuthorRole, message.Content);
return this;
}

/// <summary>
/// Add a system message to the chat history.
/// </summary>
/// <param name="content"></param>
/// <returns></returns>
public ChatSession AddSystemMessage(string content)
=> AddMessage(new ChatHistory.Message(AuthorRole.System, content));

/// <summary>
/// Add an assistant message to the chat history.
/// </summary>
/// <param name="content"></param>
/// <returns></returns>
public ChatSession AddAssistantMessage(string content)
=> AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content));

/// <summary>
/// Add a user message to the chat history.
/// </summary>
/// <param name="content"></param>
/// <returns></returns>
public ChatSession AddUserMessage(string content)
=> AddMessage(new ChatHistory.Message(AuthorRole.User, content));

StringBuilder sb = new();
/// <summary>
/// Remove the last message from the chat history.
/// </summary>
/// <returns></returns>
public ChatSession RemoveLastMessage()
{
History.Messages.RemoveAt(History.Messages.Count - 1);
return this;
}

/// <summary>
/// Replace a user message with a new message and remove all messages after the new message.
/// This is useful when the user wants to edit a message. And regenerate the response.
/// </summary>
/// <param name="oldMessage"></param>
/// <param name="newMessage"></param>
/// <returns></returns>
public ChatSession ReplaceUserMessage(
ChatHistory.Message oldMessage,
ChatHistory.Message newMessage)
{
if (oldMessage.AuthorRole != AuthorRole.User)
{
throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
}

await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
if (newMessage.AuthorRole != AuthorRole.User)
{
throw new ArgumentException("New message must be a user message", nameof(newMessage));
}

int index = History.Messages.IndexOf(oldMessage);
if (index == -1)
{
throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
}

History.Messages[index] = newMessage;

// Remove all message after the new message
History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);

return this;
}

/// <summary>
/// Chat with the model.
/// </summary>
/// <param name="message"></param>
/// <param name="inferenceParams"></param>
/// <param name="applyInputTransformPipeline"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public async IAsyncEnumerable<string> ChatAsync(
ChatHistory.Message message,
bool applyInputTransformPipeline,
IInferenceParams? inferenceParams = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// The message must be a user message
if (message.AuthorRole != AuthorRole.User)
{
throw new ArgumentException("Message must be a user message", nameof(message));
}

// Apply input transform pipeline
if (applyInputTransformPipeline)
{
foreach (var inputTransform in InputTransformPipeline)
{
yield return result;
sb.Append(result);
message.Content = inputTransform.Transform(message.Content);
}
}

// Add the user's message to the history
AddUserMessage(message.Content);

// Prepare prompt variable
string prompt;

// Check if the session history was restored from a previous session
// or added as part of new chat session history.
InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();

// If "IsPromptRun" is true, the session was newly started.
if (state.IsPromptRun)
{
// If the session history was added as part of new chat session history,
// convert the complete history includsing system message and manually added history
// to a prompt that adhere to the prompt template specified in the HistoryTransform class implementation.
prompt = HistoryTransform.HistoryToText(History);
}
else
{
// If the session was restored from a previous session,
// convert only the current message to the prompt with the prompt template
// specified in the HistoryTransform class implementation that is provided.
ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content);
prompt = HistoryTransform.HistoryToText(singleMessageHistory);
}

string assistantMessage = string.Empty;

await foreach (
string textToken
in ChatAsyncInternal(
prompt,
inferenceParams,
cancellationToken))
{
assistantMessage += textToken;
yield return textToken;
}

// Add the assistant message to the history
AddAssistantMessage(assistantMessage);
}

/// <summary>
/// Chat with the model.
/// </summary>
/// <param name="message"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public IAsyncEnumerable<string> ChatAsync(
ChatHistory.Message message,
IInferenceParams? inferenceParams = null,
CancellationToken cancellationToken = default)
{
return ChatAsync(
message,
applyInputTransformPipeline: true,
inferenceParams,
cancellationToken);
}

string assistantMessage = sb.ToString();
/// <summary>
/// Chat with the model.
/// </summary>
/// <param name="history"></param>
/// <param name="applyInputTransformPipeline"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public IAsyncEnumerable<string> ChatAsync(
ChatHistory history,
bool applyInputTransformPipeline,
IInferenceParams? inferenceParams = null,
CancellationToken cancellationToken = default)
{
ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
?? throw new ArgumentException("History must contain at least one message", nameof(history));

// Remove end tokens from the assistant message
// if defined in inferenceParams.AntiPrompts.
// We only want the response that was generated and not tokens
// that are delimiting the beginning or end of the response.
if (inferenceParams?.AntiPrompts != null)
foreach (
ChatHistory.Message message
in history.Messages.Take(history.Messages.Count - 1))
{
// Apply input transform pipeline
if (applyInputTransformPipeline
&& message.AuthorRole == AuthorRole.User)
{
foreach (var stopToken in inferenceParams.AntiPrompts)
foreach (
var inputTransform
in InputTransformPipeline)
{
assistantMessage = assistantMessage.Replace(stopToken, "");
message.Content = inputTransform.Transform(message.Content);
}
}

History.Messages.Add(new ChatHistory.Message(AuthorRole.Assistant, assistantMessage));
AddMessage(message);
}

/// <summary>
/// Generates a response for a given chat history. This method does not manage history state for the user.
/// If you want to e.g. truncate the history of a session to fit into the model's context window,
/// use this method and pass the truncated history to it. If you don't need this control, use the other
/// overload of this method that accepts a user prompt instead.
/// </summary>
/// <param name="history"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns>Returns generated text of the assistant message.</returns>
public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
return ChatAsync(
lastMessage,
applyInputTransformPipeline,
inferenceParams,
cancellationToken);
}

/// <summary>
/// Chat with the model.
/// </summary>
/// <param name="history"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public IAsyncEnumerable<string> ChatAsync(
ChatHistory history,
IInferenceParams? inferenceParams = null,
CancellationToken cancellationToken = default)
{
return ChatAsync(
history,
applyInputTransformPipeline: true,
inferenceParams,
cancellationToken);
}

/// <summary>
/// Regenerate the last assistant message.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
InferenceParams? inferenceParams = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// Make sure the last message is an assistant message (reponse from the LLM).
ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault();

if (lastAssistantMessage is null
|| lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
{
if (history.Messages.Count == 0)
{
throw new ArgumentException("History must contain at least one message.");
}
throw new InvalidOperationException("Last message must be an assistant message");
}

string prompt;
if (_executor is InteractiveExecutor executor)
{
InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
// Remove the last assistant message from the history.
RemoveLastMessage();

prompt = state.IsPromptRun
? HistoryTransform.HistoryToText(History)
: history.Messages.Last().Content;
}
else
{
prompt = history.Messages.Last().Content;
}
// Get the last user message.
ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault();

await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
}
if (lastUserMessage is null
|| lastUserMessage.AuthorRole != AuthorRole.User)
{
throw new InvalidOperationException("Last message must be a user message");
}

private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
// Remove the last user message from the history.
RemoveLastMessage();

// Regenerate the assistant message.
await foreach (
string textToken
in ChatAsync(
lastUserMessage,
applyInputTransformPipeline: false,
inferenceParams,
cancellationToken))
{
var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
await foreach (var item in OutputTransform.TransformAsync(results).WithCancellation(cancellationToken))
{
yield return item;
}
yield return textToken;
}
}

private async IAsyncEnumerable<string> ChatAsyncInternal(
string prompt,
IInferenceParams? inferenceParams = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken);

await foreach (
string textToken
in OutputTransform
.TransformAsync(results)
.WithCancellation(cancellationToken))
{
yield return textToken;
}
}
}
}
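
Note: a minimal sketch of the reworked ChatSession API above, assuming `executor` is an InteractiveExecutor created as in the examples and that this code runs inside an async method. AddMessage enforces ordering (a system message only on an empty history, user and assistant messages alternating), and ChatAsync now takes a typed ChatHistory.Message rather than a raw prompt string.

    using LLama.Common;

    var session = new ChatSession(executor)
        .AddSystemMessage("Transcript of a dialog between a User and an Assistant.");

    // ChatAsync appends the user message, streams the reply, and then records
    // the generated text as an assistant message in session.History.
    await foreach (var token in session.ChatAsync(
        new ChatHistory.Message(AuthorRole.User, "Hello!"),
        new InferenceParams { AntiPrompts = new List<string> { "User:" } }))
    {
        Console.Write(token);
    }

    // SaveSession/LoadSession persist the model state, the executor state and
    // ChatHistory.json to the given directory.
    session.SaveSession("Assets/chat-with-bob");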

+34 -6  LLama/Common/ChatHistory.cs

@@ -1,4 +1,7 @@
using System.Collections.Generic;
using System.IO;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLama.Common
{
@@ -43,11 +46,14 @@ namespace LLama.Common
/// <summary>
/// Role of the message author, e.g. user/assistant/system
/// </summary>
[JsonConverter(typeof(JsonStringEnumConverter))]
[JsonPropertyName("author_role")]
public AuthorRole AuthorRole { get; set; }

/// <summary>
/// Message content
/// </summary>
[JsonPropertyName("content")]
public string Content { get; set; }

/// <summary>
@@ -65,15 +71,14 @@ namespace LLama.Common
/// <summary>
/// List of messages in the chat
/// </summary>
public List<Message> Messages { get; }
[JsonPropertyName("messages")]
public List<Message> Messages { get; set; } = new();

/// <summary>
/// Create a new instance of the chat content class
/// </summary>
public ChatHistory()
{
this.Messages = new List<Message>();
}
[JsonConstructor]
public ChatHistory() { }

/// <summary>
/// Add a message to the chat history
@@ -84,6 +89,29 @@ namespace LLama.Common
{
this.Messages.Add(new Message(authorRole, content));
}
}

/// <summary>
/// Serialize the chat history to JSON
/// </summary>
/// <returns></returns>
public string ToJson()
{
return JsonSerializer.Serialize(
this,
new JsonSerializerOptions()
{
WriteIndented = true
});
}

/// <summary>
/// Deserialize a chat history from JSON
/// </summary>
/// <param name="json"></param>
/// <returns></returns>
public static ChatHistory? FromJson(string json)
{
return JsonSerializer.Deserialize<ChatHistory>(json);
}
}
}
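
Note: the new ToJson/FromJson pair gives a simple round trip for the history format used by the Assets/*.json files above (a minimal sketch):

    using LLama.Common;

    var history = new ChatHistory();
    history.AddMessage(AuthorRole.System, "You are a helpful assistant.");
    history.AddMessage(AuthorRole.User, "Hello!");

    // Produces indented JSON with "messages", "author_role" and "content" keys.
    string json = history.ToJson();
    ChatHistory? restored = ChatHistory.FromJson(json);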

+4 -0  LLama/Common/InferenceParams.cs

@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Common
{
@@ -76,6 +77,9 @@ namespace LLama.Common

/// <inheritdoc />
public SafeLLamaGrammarHandle? Grammar { get; set; }

/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }
}

/// <summary>


+13 -35  LLama/Common/ModelParams.cs

@@ -59,7 +59,6 @@ namespace LLama.Common
public bool EmbeddingMode { get; set; }

/// <inheritdoc />
[JsonConverter(typeof(TensorSplitsCollectionConverter))]
public TensorSplitsCollection TensorSplits { get; set; } = new();

/// <inheritdoc />
@@ -92,9 +91,20 @@ namespace LLama.Common
/// <inheritdoc />
public bool VocabOnly { get; set; }

/// <summary>
/// `Encoding` cannot be directly JSON serialized, instead store the name as a string which can
/// </summary>
[JsonPropertyName("Encoding")]
[JsonInclude]
private string EncodingName { get; set; } = Encoding.UTF8.WebName;

/// <inheritdoc />
[JsonConverter(typeof(EncodingConverter))]
public Encoding Encoding { get; set; } = Encoding.UTF8;
[JsonIgnore]
public Encoding Encoding
{
get => Encoding.GetEncoding(EncodingName);
set => EncodingName = value.WebName;
}

/// <summary>
///
@@ -112,36 +122,4 @@ namespace LLama.Common
ModelPath = "";
}
}

internal class EncodingConverter
: JsonConverter<Encoding>
{
public override Encoding? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var name = reader.GetString();
if (name == null)
return null;
return Encoding.GetEncoding(name);
}

public override void Write(Utf8JsonWriter writer, Encoding value, JsonSerializerOptions options)
{
writer.WriteStringValue(value.WebName);
}
}

internal class TensorSplitsCollectionConverter
: JsonConverter<TensorSplitsCollection>
{
public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
return new TensorSplitsCollection(arr);
}

public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
{
JsonSerializer.Serialize(writer, value.Splits, options);
}
}
}

+12 -0  LLama/LLamaContext.cs

@@ -10,6 +10,7 @@ using LLama.Common;
using System.Runtime.InteropServices;
using LLama.Extensions;
using LLama.Abstractions;
using LLama.Sampling;
using Microsoft.Extensions.Logging;

namespace LLama
@@ -212,6 +213,17 @@ namespace LLama
}
}

/// <summary>
/// Sample a single token from this context, using the given sampling pipeline
/// </summary>
/// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
/// <param name="lastTokens">The tokens recently returned from the model</param>
/// <returns>The selected token</returns>
public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
{
return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
}

/// <summary>
/// Perform the sampling. Please don't use it unless you fully know what it does.
/// </summary>
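
Note: a sketch of the new pipeline-based Sample overload as it might be called from a custom executor loop. Here `context` is an LLamaContext and `lastTokens` is a ReadOnlySpan<llama_token> of recently generated tokens; both are assumed to exist in the surrounding code.

    using LLama.Sampling;

    var pipeline = new DefaultSamplingPipeline();
    var id = context.Sample(pipeline, lastTokens);   // processes the logits and selects one token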


+17 -9  LLama/LLamaInstructExecutor.cs

@@ -210,16 +210,24 @@ namespace LLama
SaveSessionFile(_pathSession);
}

var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;
var mu = MirostatMu;
id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;
}

_last_n_tokens.Enqueue(id);



+18 -10  LLama/LLamaInteractExecutor.cs

@@ -189,16 +189,24 @@ namespace LLama
SaveSessionFile(_pathSession);
}

var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;
}

_last_n_tokens.Enqueue(id);



+1 -0  LLama/LLamaSharp.csproj

@@ -28,6 +28,7 @@
<Platforms>AnyCPU;x64;Arm64</Platforms>
<PackageId>LLamaSharp</PackageId>
<Configurations>Debug;Release;GPU</Configurations>
<GenerateAssemblyInfo>false</GenerateAssemblyInfo>
</PropertyGroup>

<PropertyGroup>


+ 19
- 10
LLama/LLamaStatelessExecutor.cs View File

@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;

namespace LLama
@@ -85,16 +86,24 @@ namespace LLama
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
// Penalize the generated tokens by various penalties
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

// Sample a single token
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
}
else
{
// Penalize the generated tokens by various penalties
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

// Sample a single token
id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
}

// Decode this token into text
decoder.Add(id);
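All three executor hunks above make the pipeline opt-in: when `inferenceParams.SamplingPipeline` is null, the existing ApplyPenalty / Context.Sample path runs exactly as before. A rough sketch of opting in from calling code follows; it assumes the `SamplingPipeline` property added to the inference parameters elsewhere in this commit and an already-constructed executor.

using System;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
using LLama.Sampling;

static class PipelineOptIn
{
    // `executor` can be any existing executor (interactive, instruct or stateless)
    public static async Task RunAsync(ILLamaExecutor executor)
    {
        var inferenceParams = new InferenceParams
        {
            MaxTokens = 128,
            // A non-null pipeline makes the executor skip ApplyPenalty / Context.Sample
            SamplingPipeline = new DefaultSamplingPipeline
            {
                Temperature = 0.6f,
                RepeatPenalty = 1.1f,
            },
        };

        await foreach (var text in executor.InferAsync("Hello", inferenceParams))
            Console.Write(text);
    }
}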


+ 34
- 5
LLama/Native/LLamaTokenDataArray.cs View File

@@ -46,14 +46,41 @@ namespace LLama.Native
return new LLamaTokenDataArray(candidates);
}

/// <summary>
/// Overwrite the logit values for all given tokens
/// </summary>
/// <param name="values">tuples of token and logit value to overwrite</param>
public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
{
if (values.Length == 0)
return;

var dataSpan = data.Span;
foreach (var (token, value) in values)
{
for (var i = 0; i < data.Length; i++)
{
if (dataSpan[i].id == token)
{
dataSpan[i].logit = value;
break;
}
}
}
sorted = false;
}

#region sampling
/// <summary>
/// Apply grammar rules to candidate tokens
/// </summary>
/// <param name="ctx"></param>
/// <param name="grammar"></param>
public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
{
if (grammar == null)
return;

using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_grammar(ctx, ref st, grammar);
@@ -145,15 +172,17 @@ namespace LLama.Native
/// <param name="penalty_repeat"></param>
/// <param name="penalty_freq"></param>
/// <param name="penalty_present"></param>
public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
{
unsafe
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
using (var last_tokens_handle = last_tokens.Pin())
{
NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
sorted = st.sorted;
fixed (int* last_tokens_handle = last_tokens)
{
NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
sorted = st.sorted;
}
}
}
}
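Both changes above exist to support the pipelines added below: OverwriteLogits lets a pipeline put saved logit values back after a transformation, and RepetitionPenalty now takes a ReadOnlySpan so callers no longer need a Memory&lt;T&gt;. A small sketch of the save / penalize / restore pattern (illustrative only; the wrapper method is hypothetical):

using System;
using LLama.Native;

static class ProtectNewlineExample
{
    // Sketch: apply repetition penalties but keep the newline logit untouched
    public static LLamaTokenDataArray Penalize(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        var nl = NativeApi.llama_token_nl(ctx.ModelHandle);
        var saved = new[] { (nl, logits[nl]) };

        var candidates = LLamaTokenDataArray.Create(logits);
        candidates.RepetitionPenalty(ctx, lastTokens, 1.1f, 0.1f, 0.1f);
        candidates.OverwriteLogits(saved);   // put the saved newline logit back

        return candidates;
    }
}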


+ 128
- 0
LLama/Sampling/BaseSamplingPipeline.cs View File

@@ -0,0 +1,128 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
/// </summary>
public abstract class BaseSamplingPipeline
: ISamplingPipeline
{
private int _savedLogitsCount;
private (int index, float logit)[]? _savedLogits;

/// <inheritdoc/>
public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
{
var protectedLogits = GetProtectedTokens(ctx);
_savedLogitsCount = protectedLogits.Count;
_savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
try
{
// Save the values of protected logits
for (var i = 0; i < protectedLogits.Count; i++)
{
var index = protectedLogits[i];
var value = logits[index];
_savedLogits[i] = (index, value);
}

// Process raw logits
ProcessLogits(ctx, logits, lastTokens);

// Automatically restore saved logit values after processing
RestoreProtectedTokens(logits);

// Convert logits into token candidates
var candidates = LLamaTokenDataArray.Create(logits);

// Process token data array
ProcessTokenDataArray(ctx, candidates, lastTokens);

// Choose the final value
return ChooseToken(ctx, candidates);
}
finally
{
ArrayPool<(int, float)>.Shared.Return(_savedLogits);
_savedLogits = null;
_savedLogitsCount = 0;
}
}

#region protected tokens
/// <summary>
/// Get all of the "protected" tokens that cannot be changed by ProcessLogits
/// </summary>
/// <returns></returns>
protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);

/// <summary>
/// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
/// </summary>
/// <param name="logits"></param>
protected void RestoreProtectedTokens(Span<float> logits)
{
if (_savedLogits == null)
return;

// The rented array may be bigger than necessary; take a span over just the valid part
var saved = _savedLogits.AsSpan(0, _savedLogitsCount);

// Restore the values of protected logits
for (var i = 0; i < saved.Length; i++)
logits[saved[i].index] = saved[i].logit;
}

/// <summary>
/// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
/// </summary>
/// <param name="candidates"></param>
protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
{
if (_savedLogits == null || _savedLogits.Length == 0)
return;

candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
}
#endregion

/// <summary>
/// Process the raw logit values
/// </summary>
/// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);

/// <summary>
/// Process the LLamaTokenDataArray and select a single token
/// </summary>
/// <param name="ctx">The context being sampled from</param>
/// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
/// <returns></returns>
protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);

/// <summary>
/// Choose the final token from the candidates
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <returns></returns>
protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);

/// <inheritdoc/>
public virtual void Reset()
{
}

/// <inheritdoc/>
public virtual void Dispose()
{
GC.SuppressFinalize(this);
}
}
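To make the contract above concrete, here is a minimal sketch of a derived pipeline (illustrative only, not part of this commit): it protects no tokens, leaves the raw logits alone, reshapes the candidate distribution with temperature in ProcessTokenDataArray, and lets ChooseToken do the final draw, matching the call order in Sample above.

using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

// Sketch only: about the simplest useful pipeline built on BaseSamplingPipeline
public sealed class TemperatureOnlyPipeline
    : BaseSamplingPipeline
{
    public float Temperature { get; set; } = 0.8f;

    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
        => Array.Empty<int>();

    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // No raw logit edits in this sketch
    }

    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        // Only reshape the distribution here; Sample above discards this return
        // value and asks ChooseToken for the final pick
        candidates.Temperature(ctx, Temperature);
        return 0;
    }

    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        => candidates.SampleToken(ctx);
}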

+ 149
- 0
LLama/Sampling/DefaultSamplingPipeline.cs View File

@@ -0,0 +1,149 @@
using System;
using System.Collections.Generic;
using LLama.Extensions;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// An implementation of ISamplingPipeline which mimics the default llama.cpp sampling
/// </summary>
public sealed class DefaultSamplingPipeline
: BaseSamplingPipeline
{
/// <summary>
/// Bias values to add to certain logits
/// </summary>
public Dictionary<int, float> LogitBias { get; } = new();

/// <summary>
/// Grammar to constrain valid tokens
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }

/// <summary>
/// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
/// </summary>
public float RepeatPenalty { get; set; } = 1.1f;

/// <summary>
/// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
/// so far, decreasing the model's likelihood to repeat the same line verbatim.
/// </summary>
public float AlphaFrequency
{
get => _alphaFreq;
set
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than or equal to -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than or equal to 2");
_alphaFreq = value;
}
}
private float _alphaFreq = 0.1f;

/// <summary>
/// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
/// text so far, increasing the model's likelihood to talk about new topics.
/// </summary>
public float AlphaPresence
{
get => _alphaPresence;
set
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than or equal to -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than or equal to 2");
_alphaPresence = value;
}
}
private float _alphaPresence = 0.1f;

/// <summary>
/// Temperature to apply (higher temperature is more "creative")
/// </summary>
public float Temperature { get; set; } = 0.75f;

/// <summary>
/// Number of tokens to keep in TopK sampling
/// </summary>
public int TopK { get; set; }

/// <summary>
/// Z value for tail free sampling
/// </summary>
public float TailFreeZ { get; set; }

/// <summary>
/// P value for locally typical sampling
/// </summary>
public float TypicalP { get; set; }

/// <summary>
/// P value for TopP sampling
/// </summary>
public float TopP { get; set; } = 1f;

/// <summary>
/// P value for MinP sampling
/// </summary>
public float MinP { get; set; }

/// <summary>
/// Whether the newline token should be penalized; when false it is protected from being modified by logit bias and repeat penalty
/// </summary>
public bool PenalizeNewline { get; set; } = false;

private readonly int[] _newlineToken = new int[1];

/// <inheritdoc />
protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
{
if (PenalizeNewline)
return Array.Empty<int>();

_newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
return _newlineToken;
}

/// <inheritdoc />
protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
{
foreach (var (key, value) in LogitBias)
logits[key] += value;
}

/// <inheritdoc />
protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
{
// Apply penalties to candidates
candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);

// Restore protected tokens, so they are not affected by repetition penalties
RestoreProtectedTokens(candidates);

// Apply the normal llama.cpp pipeline
candidates.ApplyGrammar(ctx, Grammar);
candidates.TopK(ctx, TopK);
candidates.TailFree(ctx, TailFreeZ);
candidates.LocallyTypical(ctx, TypicalP);
candidates.TopP(ctx, TopP);
candidates.MinP(ctx, MinP);
candidates.Temperature(ctx, Temperature);
var id = candidates.SampleToken(ctx);

Grammar?.AcceptToken(ctx, id);
return id;
}

/// <inheritdoc />
protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
return candidates.SampleToken(ctx);
}
}
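A short sketch of configuring the defaults above from calling code; the token id used for the logit bias is purely hypothetical.

using LLama.Sampling;

var pipeline = new DefaultSamplingPipeline
{
    Temperature = 0.6f,
    TopK = 40,
    TopP = 0.95f,
    RepeatPenalty = 1.15f,
    AlphaFrequency = 0.2f,
    AlphaPresence = 0.2f,
    PenalizeNewline = false,   // keep the newline logit protected, as above
};

// Bias one (hypothetical) token id downwards
pipeline.LogitBias[15043] = -5f;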

+ 61
- 0
LLama/Sampling/ISamplingPipeline.cs View File

@@ -0,0 +1,61 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process.
/// </summary>
public interface ISamplingPipeline
: IDisposable
{
/// <summary>
/// Sample a single token from the given logits
/// </summary>
/// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A span of tokens recently returned by the model</param>
/// <returns></returns>
int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);

/// <summary>
/// Reset all internal state of the sampling pipeline
/// </summary>
void Reset();
}

/// <summary>
/// Extensions methods for ISamplingPipeline
/// </summary>
public static class ISamplingPipelineExtensions
{
/// <summary>
/// Sample a single token from the given logits
/// </summary>
/// <param name="pipeline"></param>
/// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
/// <returns></returns>
public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(lastTokens);
return pipeline.Sample(ctx, logits, span);
#else
var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
try
{
lastTokens.CopyTo(copy);
return pipeline.Sample(ctx, logits, copy.AsSpan(0, lastTokens.Count));
}
finally
{
ArrayPool<int>.Shared.Return(copy);
}
#endif
}
}
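Finally, to show how little the interface demands, a from-scratch pipeline that ignores penalties entirely and always takes the highest logit could look like this (illustrative sketch only, not part of the commit):

using System;
using LLama.Native;
using LLama.Sampling;

// Sketch only: greedy arg-max sampling with no penalties and no internal state
public sealed class ArgMaxSamplingPipeline
    : ISamplingPipeline
{
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        var best = 0;
        for (var i = 1; i < logits.Length; i++)
        {
            if (logits[i] > logits[best])
                best = i;
        }
        return best;
    }

    public void Reset()
    {
        // Nothing to reset
    }

    public void Dispose()
    {
    }
}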
