
Minimal changes required to remove non-async inference.

Tag: v0.6.0
Author: Martin Evans, 2 years ago
Commit: 3f80190f85
20 changed files with 108 additions and 181 deletions
  1. LLama.Examples/NewVersion/ChatSessionStripRoleName.cs (+2 -2)
  2. LLama.Examples/NewVersion/ChatSessionWithRoleName.cs (+2 -2)
  3. LLama.Examples/NewVersion/GrammarJsonResponse.cs (+3 -3)
  4. LLama.Examples/NewVersion/InstructModeExecute.cs (+2 -2)
  5. LLama.Examples/NewVersion/LoadAndSaveSession.cs (+2 -2)
  6. LLama.Examples/NewVersion/LoadAndSaveState.cs (+2 -2)
  7. LLama.Examples/NewVersion/StatelessModeExecute.cs (+2 -2)
  8. LLama.Examples/NewVersion/TestRunner.cs (+7 -7)
  9. LLama.Unittest/GrammarTest.cs (+2 -2)
  10. LLama.Unittest/LLama.Unittest.csproj (+1 -0)
  11. LLama.Unittest/StatelessExecutorTest.cs (+6 -6)
  12. LLama.WebAPI/Controllers/ChatController.cs (+1 -1)
  13. LLama.WebAPI/Services/StatefulChatService.cs (+3 -3)
  14. LLama/Abstractions/ILLamaExecutor.cs (+0 -9)
  15. LLama/Abstractions/ITextStreamTransform.cs (+0 -7)
  16. LLama/ChatSession.cs (+3 -53)
  17. LLama/LLamaExecutorBase.cs (+22 -30)
  18. LLama/LLamaInstructExecutor.cs (+24 -17)
  19. LLama/LLamaInteractExecutor.cs (+23 -21)
  20. LLama/LLamaStatelessExecutor.cs (+1 -10)
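For callers, the practical effect of this commit is that the synchronous Infer/Chat entry points disappear and the async streaming overloads become the only way to run inference. A minimal sketch of the resulting call pattern, assuming the 0.6-era loading API; the model path, executor choice, and parameter values are illustrative placeholders rather than code taken from this commit:

using System;
using LLama;
using LLama.Common;

// Load a model and build an executor (placeholder path; point this at a real .gguf file).
var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

// Before this commit: foreach (var text in executor.Infer(prompt, inferenceParams)) { ... }
// After: the loop becomes await foreach, so the containing method must be async.
var prompt = "Question: what is a cat?\nAnswer: ";
await foreach (var text in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 64, AntiPrompts = new[] { "." } }))
{
    Console.Write(text);
}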

LLama.Examples/NewVersion/ChatSessionStripRoleName.cs (+2 -2)

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class ChatSessionStripRoleName
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
Console.Write(prompt);
while (true)
{
foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
{
Console.Write(text);
}


LLama.Examples/NewVersion/ChatSessionWithRoleName.cs (+2 -2)

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class ChatSessionWithRoleName
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
Console.Write(prompt);
while (true)
{
foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
{
Console.Write(text);
}


LLama.Examples/NewVersion/GrammarJsonResponse.cs (+3 -3)

@@ -5,9 +5,9 @@ namespace LLama.Examples.NewVersion
{
public class GrammarJsonResponse
{
public static void Run()
public static async Task Run()
{
var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
var grammar = Grammar.Parse(gbnf, "root");

Console.Write("Please input your model path: ");
@@ -43,7 +43,7 @@ namespace LLama.Examples.NewVersion
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Answer: ");
prompt = $"Question: {prompt?.Trim()} Answer: ";
foreach (var text in ex.Infer(prompt, inferenceParams))
await foreach (var text in ex.InferAsync(prompt, inferenceParams))
{
Console.Write(text);
}


LLama.Examples/NewVersion/InstructModeExecute.cs (+2 -2)

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class InstructModeExecute
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion

while (true)
{
foreach (var text in executor.Infer(prompt, inferenceParams))
await foreach (var text in executor.InferAsync(prompt, inferenceParams))
{
Console.Write(text);
}


LLama.Examples/NewVersion/LoadAndSaveSession.cs (+2 -2)

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class SaveAndLoadSession
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
Console.Write(prompt);
while (true)
{
foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
{
Console.Write(text);
}


LLama.Examples/NewVersion/LoadAndSaveState.cs (+2 -2)

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class LoadAndSaveState
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion

while (true)
{
foreach (var text in ex.Infer(prompt, inferenceParams))
await foreach (var text in ex.InferAsync(prompt, inferenceParams))
{
Console.Write(text);
}


LLama.Examples/NewVersion/StatelessModeExecute.cs (+2 -2)

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class StatelessModeExecute
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@@ -35,7 +35,7 @@ namespace LLama.Examples.NewVersion
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Answer: ");
prompt = $"Question: {prompt?.Trim()} Answer: ";
foreach (var text in ex.Infer(prompt, inferenceParams))
await foreach (var text in ex.InferAsync(prompt, inferenceParams))
{
Console.Write(text);
}


LLama.Examples/NewVersion/TestRunner.cs (+7 -7)

@@ -29,11 +29,11 @@

if (choice == 0)
{
ChatSessionWithRoleName.Run();
await ChatSessionWithRoleName.Run();
}
else if (choice == 1)
{
ChatSessionStripRoleName.Run();
await ChatSessionStripRoleName.Run();
}
else if(choice == 2)
{
@@ -41,19 +41,19 @@
}
else if(choice == 3)
{
InstructModeExecute.Run();
await InstructModeExecute.Run();
}
else if(choice == 4)
{
StatelessModeExecute.Run();
await StatelessModeExecute.Run();
}
else if(choice == 5)
{
SaveAndLoadSession.Run();
await SaveAndLoadSession.Run();
}
else if(choice == 6)
{
LoadAndSaveState.Run();
await LoadAndSaveState.Run();
}
else if(choice == 7)
{
@@ -69,7 +69,7 @@
}
else if (choice == 10)
{
GrammarJsonResponse.Run();
await GrammarJsonResponse.Run();
}
else if (choice == 11)
{


LLama.Unittest/GrammarTest.cs (+2 -2)

@@ -41,7 +41,7 @@ namespace LLama.Unittest
}

[Fact]
public void SampleWithTrivialGrammar()
public async Task SampleWithTrivialGrammar()
{
// Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
// we can be confident it's not what the LLM would say if not constrained by the grammar!
@@ -66,7 +66,7 @@ namespace LLama.Unittest
Grammar = grammar,
};

var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();
var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();

Assert.Equal("cat", result[0]);
}


LLama.Unittest/LLama.Unittest.csproj (+1 -0)

@@ -12,6 +12,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
<PackageReference Include="xunit" Version="2.5.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
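The new package reference exists because the updated tests call ToListAsync() on the executors' IAsyncEnumerable<string> streams; that extension method comes from the System.Linq.Async package rather than the base class library. A small hedged sketch of the pattern the tests rely on (the async source below is a stand-in, not code from this commit):

using System.Collections.Generic;
using System.Linq;            // ToListAsync is provided here by System.Linq.Async
using System.Threading.Tasks;

static class ToListAsyncDemo
{
    // Stand-in async stream, analogous to executor.InferAsync(...)
    static async IAsyncEnumerable<string> Tokens()
    {
        yield return "Hello";
        await Task.Delay(1);
        yield return " world";
    }

    static async Task Main()
    {
        // Materialise the stream, as the updated unit tests do with InferAsync(...)
        var parts = await Tokens().ToListAsync();
        System.Console.WriteLine(string.Concat(parts));
    }
}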


LLama.Unittest/StatelessExecutorTest.cs (+6 -6)

@@ -27,15 +27,15 @@ namespace LLama.Unittest
}

[Fact]
public void Stateless()
public async Task Stateless()
{
var executor = new StatelessExecutor(_weights, _params);

const string question = "Question. what is a cat?\nAnswer: ";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };

var result1 = string.Join("", executor.Infer(question, @params));
var result2 = string.Join("", executor.Infer(question, @params));
var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

_testOutputHelper.WriteLine(result1);

@@ -44,7 +44,7 @@ namespace LLama.Unittest
}

[Fact]
public void OutOfContext()
public async Task OutOfContext()
{
var executor = new StatelessExecutor(_weights, _params);

@@ -58,8 +58,8 @@ namespace LLama.Unittest
TokensKeep = question.Length,
};

var result1 = string.Join("", executor.Infer(question, @params));
var result2 = string.Join("", executor.Infer(question, @params));
var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

_testOutputHelper.WriteLine(result1);



LLama.WebAPI/Controllers/ChatController.cs (+1 -1)

@@ -18,7 +18,7 @@ namespace LLama.WebAPI.Controllers
}

[HttpPost("Send")]
public string SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
public Task<string> SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
{
return _service.Send(input);
}


LLama.WebAPI/Services/StatefulChatService.cs (+3 -3)

@@ -28,7 +28,7 @@ public class StatefulChatService : IDisposable
_context?.Dispose();
}

public string Send(SendMessageInput input)
public async Task<string> Send(SendMessageInput input)
{
var userInput = input.Text;
if (!_continue)
@@ -42,13 +42,13 @@ public class StatefulChatService : IDisposable
Console.Write(input.Text);

Console.ForegroundColor = ConsoleColor.White;
var outputs = _session.Chat(userInput, new Common.InferenceParams()
var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
{
RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" },
});
var result = "";
foreach (var output in outputs)
await foreach (var output in outputs)
{
Console.Write(output);
result += output;


LLama/Abstractions/ILLamaExecutor.cs (+0 -9)

@@ -13,15 +13,6 @@ namespace LLama.Abstractions
/// </summary>
public LLamaContext Context { get; }

/// <summary>
/// Infers a response from the model.
/// </summary>
/// <param name="text">Your prompt</param>
/// <param name="inferenceParams">Any additional parameters</param>
/// <param name="token">A cancellation token.</param>
/// <returns></returns>
IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);

/// <summary>
/// Asynchronously infers a response from the model.
/// </summary>


LLama/Abstractions/ITextStreamTransform.cs (+0 -7)

@@ -7,13 +7,6 @@ namespace LLama.Abstractions
/// </summary>
public interface ITextStreamTransform
{
/// <summary>
/// Takes a stream of tokens and transforms them, returning a new stream of tokens.
/// </summary>
/// <param name="tokens"></param>
/// <returns></returns>
IEnumerable<string> Transform(IEnumerable<string> tokens);

/// <summary>
/// Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously.
/// </summary>


LLama/ChatSession.cs (+3 -53)

@@ -134,26 +134,6 @@ namespace LLama
}
}

/// <summary>
/// Get the response from the LLama model with chat histories.
/// </summary>
/// <param name="history"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
var prompt = HistoryTransform.HistoryToText(history);
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
StringBuilder sb = new();
foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
}
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
}

/// <summary>
/// Get the response from the LLama model. Note that prompt could not only be the preset words,
/// but also the question you want to ask.
@@ -162,15 +142,14 @@ namespace LLama
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach(var inputTransform in InputTransformPipeline)
{
prompt = inputTransform.Transform(prompt);
}
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
StringBuilder sb = new();
foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
@@ -198,35 +177,6 @@ namespace LLama
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
}

/// <summary>
/// Get the response from the LLama model with chat histories asynchronously.
/// </summary>
/// <param name="prompt"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var inputTransform in InputTransformPipeline)
{
prompt = inputTransform.Transform(prompt);
}
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
StringBuilder sb = new();
await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
}
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
}

private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
return OutputTransform.Transform(results);
}

private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);


LLama/LLamaExecutorBase.cs (+22 -30)

@@ -10,6 +10,7 @@ using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;

namespace LLama
{
@@ -212,47 +213,53 @@ namespace LLama
/// </summary>
/// <param name="args"></param>
/// <returns></returns>
protected abstract bool GetLoopCondition(InferStateArgs args);
protected abstract Task<bool> GetLoopCondition(InferStateArgs args);

/// <summary>
/// Preprocess the inputs before the inference.
/// </summary>
/// <param name="text"></param>
/// <param name="args"></param>
protected abstract void PreprocessInputs(string text, InferStateArgs args);
protected abstract Task PreprocessInputs(string text, InferStateArgs args);

/// <summary>
/// Do some post processing after the inference.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <param name="extraOutputs"></param>
/// <returns></returns>
protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);

/// <summary>
/// The core inference logic.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);

/// <summary>
/// Save the current state to a file.
/// </summary>
/// <param name="filename"></param>
public abstract void SaveState(string filename);
public abstract Task SaveState(string filename);

/// <summary>
/// Get the current state data.
/// </summary>
/// <returns></returns>
public abstract ExecutorBaseState GetStateData();

/// <summary>
/// Load the state from data.
/// </summary>
/// <param name="data"></param>
public abstract void LoadState(ExecutorBaseState data);
public abstract Task LoadState(ExecutorBaseState data);

/// <summary>
/// Load the state from a file.
/// </summary>
/// <param name="filename"></param>
public abstract void LoadState(string filename);
public abstract Task LoadState(string filename);


/// <summary>
@@ -262,12 +269,12 @@ namespace LLama
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
inferenceParams ??= new InferenceParams();

InferStateArgs args = new InferStateArgs()
var args = new InferStateArgs
{
Antiprompts = inferenceParams.AntiPrompts.ToList(),
RemainedTokens = inferenceParams.MaxTokens,
@@ -276,15 +283,15 @@ namespace LLama
NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
};

PreprocessInputs(text, args);
await PreprocessInputs(text, args);

while (GetLoopCondition(args))
while (await GetLoopCondition(args))
{
if (cancellationToken.IsCancellationRequested)
{
break;
}
InferInternal(inferenceParams, args);
await InferInternal(inferenceParams, args);

if (args.ReturnValue)
{
@@ -292,8 +299,8 @@ namespace LLama
yield return Context.TokenToString(id);
}

var breakGeneration = PostProcess(inferenceParams, args, out var extraOutputs);
if (extraOutputs is not null)
var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
if (extraOutputs is { Count: > 0 })
{
foreach (var item in extraOutputs)
{
@@ -307,21 +314,6 @@ namespace LLama
}
}

/// <summary>
/// Execute the inference asynchronously.
/// </summary>
/// <param name="text"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var result in Infer(text, inferenceParams, cancellationToken))
{
yield return result;
}
}

/// <summary>
/// State arguments that are used in single inference
/// </summary>


LLama/LLamaInstructExecutor.cs (+24 -17)

@@ -5,9 +5,9 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using LLama.Extensions;

namespace LLama
@@ -60,7 +60,7 @@ namespace LLama
return state;
}
/// <inheritdoc />
public override void LoadState(ExecutorBaseState data)
public override Task LoadState(ExecutorBaseState data)
{
if(data is InstructExecutorState state)
{
@@ -81,34 +81,37 @@ namespace LLama
{
throw new ArgumentException("Invalid state data type.");
}

return Task.CompletedTask;
}

/// <inheritdoc />
public override void SaveState(string filename)
public override async Task SaveState(string filename)
{
var state = (InstructExecutorState)GetStateData();
using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
JsonSerializer.Serialize(fs, state);
await JsonSerializer.SerializeAsync(fs, state);
}
}
/// <inheritdoc />
public override void LoadState(string filename)
public override async Task LoadState(string filename)
{
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = JsonSerializer.Deserialize<InstructExecutorState>(fs);
LoadState(state);
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
await LoadState(state);
}
}

/// <inheritdoc />
protected override bool GetLoopCondition(InferStateArgs args)
protected override Task<bool> GetLoopCondition(InferStateArgs args)
{
return args.RemainedTokens != 0 || _is_prompt_run;
return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
}

/// <inheritdoc />
protected override void PreprocessInputs(string text, InferStateArgs args)
protected override Task PreprocessInputs(string text, InferStateArgs args)
{
args.Antiprompts ??= new List<string>();
args.Antiprompts.Add(_instructionPrefix);
@@ -133,23 +136,24 @@ namespace LLama

args.RemainedTokens -= line_inp.Length;
}

return Task.CompletedTask;
}

/// <inheritdoc />
protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
{
args.WaitForInput = true;
return true;
return (true, Array.Empty<string>());
}

if (_pastTokensCount > 0 && args.WaitForInput)
{
extraOutputs = new[] { "\n> " };
return true;
return (true, new[] { "\n> " });
}
}

@@ -163,10 +167,11 @@ namespace LLama
args.RemainedTokens = inferenceParams.MaxTokens;
args.WaitForInput = true;
}
return false;
return (false, Array.Empty<string>());
}

/// <inheritdoc />
protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embeds.Count > 0)
{
@@ -230,6 +235,8 @@ namespace LLama
}
}
}

return Task.CompletedTask;
}
/// <summary>
/// The descriptor of the state of the instruct executor.


LLama/LLamaInteractExecutor.cs (+23 -21)

@@ -7,7 +7,7 @@ using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text;
using System.Threading.Tasks;
using LLama.Extensions;

namespace LLama
@@ -51,7 +51,7 @@ namespace LLama
return state;
}
/// <inheritdoc />
public override void LoadState(ExecutorBaseState data)
public override Task LoadState(ExecutorBaseState data)
{
if (data is InteractiveExecutorState state)
{
@@ -68,23 +68,25 @@ namespace LLama
}
else
throw new ArgumentException("Invalid state data type.");

return Task.CompletedTask;
}
/// <inheritdoc />
public override void SaveState(string filename)
public override async Task SaveState(string filename)
{
InteractiveExecutorState state = (InteractiveExecutorState)GetStateData();
using(FileStream fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
var state = (InteractiveExecutorState)GetStateData();
using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
JsonSerializer.Serialize(fs, state);
await JsonSerializer.SerializeAsync(fs, state);
}
}
/// <inheritdoc />
public override void LoadState(string filename)
public override async Task LoadState(string filename)
{
using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = JsonSerializer.Deserialize<InteractiveExecutorState>(fs);
LoadState(state);
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state);
}
}

@@ -92,13 +94,13 @@ namespace LLama
/// Define whether to continue the loop to generate responses.
/// </summary>
/// <returns></returns>
protected override bool GetLoopCondition(InferStateArgs args)
protected override Task<bool> GetLoopCondition(InferStateArgs args)
{
return args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run;
return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
}

/// <inheritdoc />
protected override void PreprocessInputs(string text, InferStateArgs args)
protected override Task PreprocessInputs(string text, InferStateArgs args)
{
if (_is_prompt_run)
{
@@ -115,6 +117,8 @@ namespace LLama
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
}

return Task.CompletedTask;
}

/// <summary>
@@ -122,24 +126,21 @@ namespace LLama
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <param name="extraOutputs"></param>
/// <returns></returns>
protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
args.WaitForInput = true;

if (_pastTokensCount > 0 && args.WaitForInput)
return true;
return (true, Array.Empty<string>());
}

if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
{
extraOutputs = new[] { " [end of text]\n" };
return true;
return (true, new[] { " [end of text]\n" });
}

if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
@@ -147,11 +148,12 @@ namespace LLama
args.RemainedTokens = inferenceParams.MaxTokens;
args.WaitForInput = true;
}
return false;

return (false, Array.Empty<string>());
}

/// <inheritdoc />
protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embeds.Count > 0)
{


LLama/LLamaStatelessExecutor.cs (+1 -10)

@@ -55,7 +55,7 @@ namespace LLama
}

/// <inheritdoc />
public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
using var context = _weights.CreateContext(_params);
Context = context;
@@ -140,14 +140,5 @@ namespace LLama
{
return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
}

/// <inheritdoc />
public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var result in Infer(text, inferenceParams, cancellationToken))
{
yield return result;
}
}
}
}
