diff --git a/LLama.Web/Common/LLamaExecutorType.cs b/LLama.Web/Common/LLamaExecutorType.cs new file mode 100644 index 00000000..b0be310c --- /dev/null +++ b/LLama.Web/Common/LLamaExecutorType.cs @@ -0,0 +1,9 @@ +namespace LLama.Web.Common +{ + public enum LLamaExecutorType + { + Interactive = 0, + Instruct = 1, + Stateless = 2 + } +} diff --git a/LLama.Web/Common/ServiceResult.cs b/LLama.Web/Common/ServiceResult.cs new file mode 100644 index 00000000..709a6d3a --- /dev/null +++ b/LLama.Web/Common/ServiceResult.cs @@ -0,0 +1,41 @@ +namespace LLama.Web.Common +{ + public class ServiceResult : ServiceResult, IServiceResult + { + public T Value { get; set; } + } + + + public class ServiceResult + { + public string Error { get; set; } + + public bool HasError + { + get { return !string.IsNullOrEmpty(Error); } + } + + public static IServiceResult FromValue(T value) + { + return new ServiceResult + { + Value = value, + }; + } + + public static IServiceResult FromError(string error) + { + return new ServiceResult + { + Error = error, + }; + } + } + + public interface IServiceResult + { + T Value { get; set; } + string Error { get; set; } + bool HasError { get; } + } +} diff --git a/LLama.Web/Hubs/InteractiveHub.cs b/LLama.Web/Hubs/SessionConnectionHub.cs similarity index 62% rename from LLama.Web/Hubs/InteractiveHub.cs rename to LLama.Web/Hubs/SessionConnectionHub.cs index 2bbdfd96..080866c6 100644 --- a/LLama.Web/Hubs/InteractiveHub.cs +++ b/LLama.Web/Hubs/SessionConnectionHub.cs @@ -2,60 +2,58 @@ using LLama.Web.Models; using LLama.Web.Services; using Microsoft.AspNetCore.SignalR; -using Microsoft.Extensions.Options; using System.Diagnostics; namespace LLama.Web.Hubs { - public class InteractiveHub : Hub + public class SessionConnectionHub : Hub { - private readonly LLamaOptions _options; - private readonly ILogger _logger; - private readonly IModelSessionService _modelSessionService; + private readonly ILogger _logger; + private readonly ConnectionSessionService 
_modelSessionService; - public InteractiveHub(ILogger logger, IOptions options, IModelSessionService modelSessionService) + public SessionConnectionHub(ILogger logger, ConnectionSessionService modelSessionService) { _logger = logger; - _options = options.Value; _modelSessionService = modelSessionService; } - public override async Task OnConnectedAsync() { - _logger.Log(LogLevel.Information, "OnConnectedAsync, Id: {0}", Context.ConnectionId); - await base.OnConnectedAsync(); + _logger.Log(LogLevel.Information, "[OnConnectedAsync], Id: {0}", Context.ConnectionId); + + // Notify client of successful connection await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Connected); + await base.OnConnectedAsync(); } public override async Task OnDisconnectedAsync(Exception? exception) { _logger.Log(LogLevel.Information, "[OnDisconnectedAsync], Id: {0}", Context.ConnectionId); - await _modelSessionService.RemoveAsync(Context.ConnectionId); + + // Remove connection's session on disconnect + await _modelSessionService.RemoveAsync(Context.ConnectionId); await base.OnDisconnectedAsync(exception); } [HubMethodName("LoadModel")] - public async Task OnLoadModel(string modelName, string promptName, string parameterName) + public async Task OnLoadModel(LLamaExecutorType executorType, string modelName, string promptName, string parameterName) { _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}, Model: {1}, Prompt: {2}, Parameter: {3}", Context.ConnectionId, modelName, promptName, parameterName); + + // Remove existing connection's session await _modelSessionService.RemoveAsync(Context.ConnectionId); - var modelOption = _options.Models.First(x => x.Name == modelName); - var promptOption = _options.Prompts.First(x => x.Name == promptName); - var parameterOption = _options.Parameters.First(x => x.Name == parameterName); - var interactiveExecutor = new InteractiveExecutor(new LLamaModel(modelOption)); - var modelSession = await 
_modelSessionService.CreateAsync(Context.ConnectionId, interactiveExecutor, modelOption, promptOption, parameterOption); - if (modelSession is null) + // Create model session + var modelSessionResult = await _modelSessionService.CreateAsync(executorType, Context.ConnectionId, modelName, promptName, parameterName); + if (modelSessionResult.HasError) { - _logger.Log(LogLevel.Error, "[OnLoadModel] - Failed to add new model session, Connection: {0}", Context.ConnectionId); - await Clients.Caller.OnError("No model has been loaded"); + await Clients.Caller.OnError(modelSessionResult.Error); return; - } - _logger.Log(LogLevel.Information, "[OnLoadModel] - New model session added, Connection: {0}", Context.ConnectionId); + + // Notify client await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Loaded); } @@ -63,16 +61,17 @@ namespace LLama.Web.Hubs [HubMethodName("SendPrompt")] public async Task OnSendPrompt(string prompt) { - var stopwatch = Stopwatch.GetTimestamp(); _logger.Log(LogLevel.Information, "[OnSendPrompt] - New prompt received, Connection: {0}", Context.ConnectionId); + + // Get connections session var modelSession = await _modelSessionService.GetAsync(Context.ConnectionId); if (modelSession is null) { - _logger.Log(LogLevel.Warning, "[OnSendPrompt] - No model has been loaded for this connection, Connection: {0}", Context.ConnectionId); await Clients.Caller.OnError("No model has been loaded"); return; } + // Create unique response id var responseId = Guid.NewGuid().ToString(); @@ -80,6 +79,7 @@ namespace LLama.Web.Hubs await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true)); // Send content of response + var stopwatch = Stopwatch.GetTimestamp(); await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted))) { await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment)); @@ -93,6 +93,6 @@ namespace LLama.Web.Hubs await 
Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true)); _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled()); } - + } } diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index 8eeea3f6..d6d42813 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -25,6 +25,11 @@ namespace LLama.Web.Models _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5); } + public string ModelName + { + get { return _modelOptions.Name; } + } + public IAsyncEnumerable InferAsync(string message, CancellationTokenSource cancellationTokenSource) { _cancellationTokenSource = cancellationTokenSource; diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml b/LLama.Web/Pages/Executor/Instruct.cshtml new file mode 100644 index 00000000..9f8cb2d8 --- /dev/null +++ b/LLama.Web/Pages/Executor/Instruct.cshtml @@ -0,0 +1,96 @@ +@page +@model InstructModel +@{ + +} +@Html.AntiForgeryToken() +
+ +
+
+

Instruct

+
+ Hub: Disconnected +
+
+ +
+ Model + +
+ +
+ Parameters + +
+ +
+ Prompt + + +
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+
+
+ +
+
+ +
+
+ +
+
+ + +
+
+
+
+ +
+
+ +@{ await Html.RenderPartialAsync("_ChatTemplates"); } + +@section Scripts { + + +} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml.cs b/LLama.Web/Pages/Executor/Instruct.cshtml.cs new file mode 100644 index 00000000..18a58253 --- /dev/null +++ b/LLama.Web/Pages/Executor/Instruct.cshtml.cs @@ -0,0 +1,34 @@ +using LLama.Web.Common; +using LLama.Web.Models; +using LLama.Web.Services; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.RazorPages; +using Microsoft.Extensions.Options; + +namespace LLama.Web.Pages +{ + public class InstructModel : PageModel + { + private readonly ILogger _logger; + private readonly ConnectionSessionService _modelSessionService; + + public InstructModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) + { + _logger = logger; + Options = options.Value; + _modelSessionService = modelSessionService; + } + + public LLamaOptions Options { get; set; } + + public void OnGet() + { + } + + public async Task OnPostCancel(CancelModel model) + { + await _modelSessionService.CancelAsync(model.ConnectionId); + return new JsonResult(default); + } + } +} \ No newline at end of file diff --git a/LLama.Web/Pages/Interactive.cshtml.css b/LLama.Web/Pages/Executor/Instruct.cshtml.css similarity index 100% rename from LLama.Web/Pages/Interactive.cshtml.css rename to LLama.Web/Pages/Executor/Instruct.cshtml.css diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml b/LLama.Web/Pages/Executor/Interactive.cshtml new file mode 100644 index 00000000..916b59ca --- /dev/null +++ b/LLama.Web/Pages/Executor/Interactive.cshtml @@ -0,0 +1,96 @@ +@page +@model InteractiveModel +@{ + +} +@Html.AntiForgeryToken() +
+ +
+
+

Interactive

+
+ Hub: Disconnected +
+
+ +
+ Model + +
+ +
+ Parameters + +
+ +
+ Prompt + + +
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+
+
+ +
+
+ +
+
+ +
+
+ + +
+
+
+
+ +
+
+ +@{ await Html.RenderPartialAsync("_ChatTemplates");} + +@section Scripts { + + +} \ No newline at end of file diff --git a/LLama.Web/Pages/Interactive.cshtml.cs b/LLama.Web/Pages/Executor/Interactive.cshtml.cs similarity index 84% rename from LLama.Web/Pages/Interactive.cshtml.cs rename to LLama.Web/Pages/Executor/Interactive.cshtml.cs index c209a5b5..7179a440 100644 --- a/LLama.Web/Pages/Interactive.cshtml.cs +++ b/LLama.Web/Pages/Executor/Interactive.cshtml.cs @@ -10,9 +10,9 @@ namespace LLama.Web.Pages public class InteractiveModel : PageModel { private readonly ILogger _logger; - private readonly IModelSessionService _modelSessionService; + private readonly ConnectionSessionService _modelSessionService; - public InteractiveModel(ILogger logger, IOptions options, IModelSessionService modelSessionService) + public InteractiveModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) { _logger = logger; Options = options.Value; diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml.css b/LLama.Web/Pages/Executor/Interactive.cshtml.css new file mode 100644 index 00000000..ed9a1d59 --- /dev/null +++ b/LLama.Web/Pages/Executor/Interactive.cshtml.css @@ -0,0 +1,4 @@ +.section-content { + flex: 1; + overflow-y: scroll; +} diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml b/LLama.Web/Pages/Executor/Stateless.cshtml new file mode 100644 index 00000000..b5d8eea3 --- /dev/null +++ b/LLama.Web/Pages/Executor/Stateless.cshtml @@ -0,0 +1,97 @@ +@page +@model StatelessModel +@{ + +} +@Html.AntiForgeryToken() +
+ +
+
+

Stateless

+
+ Hub: Disconnected +
+
+ +
+ Model + +
+ +
+ Parameters + +
+ +
+ Prompt + + +
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+
+
+ +
+
+ +
+
+ +
+
+ + +
+
+
+
+ +
+
+ +@{ await Html.RenderPartialAsync("_ChatTemplates"); } + + +@section Scripts { + + +} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml.cs b/LLama.Web/Pages/Executor/Stateless.cshtml.cs new file mode 100644 index 00000000..f88c4b83 --- /dev/null +++ b/LLama.Web/Pages/Executor/Stateless.cshtml.cs @@ -0,0 +1,34 @@ +using LLama.Web.Common; +using LLama.Web.Models; +using LLama.Web.Services; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.RazorPages; +using Microsoft.Extensions.Options; + +namespace LLama.Web.Pages +{ + public class StatelessModel : PageModel + { + private readonly ILogger _logger; + private readonly ConnectionSessionService _modelSessionService; + + public StatelessModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) + { + _logger = logger; + Options = options.Value; + _modelSessionService = modelSessionService; + } + + public LLamaOptions Options { get; set; } + + public void OnGet() + { + } + + public async Task OnPostCancel(CancelModel model) + { + await _modelSessionService.CancelAsync(model.ConnectionId); + return new JsonResult(default); + } + } +} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml.css b/LLama.Web/Pages/Executor/Stateless.cshtml.css new file mode 100644 index 00000000..ed9a1d59 --- /dev/null +++ b/LLama.Web/Pages/Executor/Stateless.cshtml.css @@ -0,0 +1,4 @@ +.section-content { + flex: 1; + overflow-y: scroll; +} diff --git a/LLama.Web/Pages/Interactive.cshtml b/LLama.Web/Pages/Interactive.cshtml deleted file mode 100644 index 5224839a..00000000 --- a/LLama.Web/Pages/Interactive.cshtml +++ /dev/null @@ -1,338 +0,0 @@ -@page -@model InteractiveModel -@{ - -} -@Html.AntiForgeryToken() -
- -
-
-

Interactive

-
- Hub: Disconnected -
-
- -
- Model - -
- -
- Parameters - -
- -
- Prompt - - -
- -
- -
-
- -
- -
-
- -
-
-
- -
-
-
-
- -
- -
- -
-
- -
-
- - -
-
- -
- -
-
- - -
- - - - - - - - - - - - -@section Scripts { - - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml new file mode 100644 index 00000000..15644012 --- /dev/null +++ b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml @@ -0,0 +1,60 @@ + + + + + + + + + \ No newline at end of file diff --git a/LLama.Web/Pages/Shared/_Layout.cshtml b/LLama.Web/Pages/Shared/_Layout.cshtml index 01c65651..23132bfa 100644 --- a/LLama.Web/Pages/Shared/_Layout.cshtml +++ b/LLama.Web/Pages/Shared/_Layout.cshtml @@ -24,7 +24,13 @@ Home + + diff --git a/LLama.Web/Program.cs b/LLama.Web/Program.cs index 303e9da0..6db653a1 100644 --- a/LLama.Web/Program.cs +++ b/LLama.Web/Program.cs @@ -20,7 +20,7 @@ namespace LLama.Web .BindConfiguration(nameof(LLamaOptions)); // Services DI - builder.Services.AddSingleton(); + builder.Services.AddSingleton(); var app = builder.Build(); @@ -41,7 +41,7 @@ namespace LLama.Web app.MapRazorPages(); - app.MapHub(nameof(InteractiveHub)); + app.MapHub(nameof(SessionConnectionHub)); app.Run(); } diff --git a/LLama.Web/Services/ConnectionSessionService.cs b/LLama.Web/Services/ConnectionSessionService.cs new file mode 100644 index 00000000..6c266f14 --- /dev/null +++ b/LLama.Web/Services/ConnectionSessionService.cs @@ -0,0 +1,94 @@ +using LLama.Abstractions; +using LLama.Web.Common; +using LLama.Web.Models; +using Microsoft.Extensions.Options; +using System.Collections.Concurrent; +using System.Drawing; + +namespace LLama.Web.Services +{ + /// + /// Example Service for handling a model session for a websockets connection lifetime + /// Each websocket connection will create its own unique session and context allowing you to use multiple tabs to compare prompts etc + /// + public class ConnectionSessionService : IModelSessionService + { + private readonly LLamaOptions _options; + private readonly ILogger _logger; + private readonly ConcurrentDictionary _modelSessions; + + public 
ConnectionSessionService(ILogger logger, IOptions options) + { + _logger = logger; + _options = options.Value; + _modelSessions = new ConcurrentDictionary(); + } + + public Task GetAsync(string connectionId) + { + _modelSessions.TryGetValue(connectionId, out var modelSession); + return Task.FromResult(modelSession); + } + + public Task> CreateAsync(LLamaExecutorType executorType, string connectionId, string modelName, string promptName, string parameterName) + { + var modelOption = _options.Models.FirstOrDefault(x => x.Name == modelName); + if (modelOption is null) + return Task.FromResult(ServiceResult.FromError($"Model option '{modelName}' not found")); + + var promptOption = _options.Prompts.FirstOrDefault(x => x.Name == promptName); + if (promptOption is null) + return Task.FromResult(ServiceResult.FromError($"Prompt option '{promptName}' not found")); + + var parameterOption = _options.Parameters.FirstOrDefault(x => x.Name == parameterName); + if (parameterOption is null) + return Task.FromResult(ServiceResult.FromError($"Parameter option '{parameterName}' not found")); + + + //Max instance + var currentInstances = _modelSessions.Count(x => x.Value.ModelName == modelOption.Name); + if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances) + return Task.FromResult(ServiceResult.FromError("Maximum model instances reached")); + + // Create model + var llamaModel = new LLamaModel(modelOption); + + // Create executor + ILLamaExecutor executor = executorType switch + { + LLamaExecutorType.Interactive => new InteractiveExecutor(llamaModel), + LLamaExecutorType.Instruct => new InstructExecutor(llamaModel), + LLamaExecutorType.Stateless => new StatelessExecutor(llamaModel), + _ => default + }; + + // Create session + var modelSession = new ModelSession(executor, modelOption, promptOption, parameterOption); + if (!_modelSessions.TryAdd(connectionId, modelSession)) + return Task.FromResult(ServiceResult.FromError("Failed to create model 
session")); + + return Task.FromResult(ServiceResult.FromValue(modelSession)); + } + + public Task RemoveAsync(string connectionId) + { + if (_modelSessions.TryRemove(connectionId, out var modelSession)) + { + modelSession.CancelInfer(); + modelSession.Dispose(); + return Task.FromResult(true); + } + return Task.FromResult(false); + } + + public Task CancelAsync(string connectionId) + { + if (_modelSessions.TryGetValue(connectionId, out var modelSession)) + { + modelSession.CancelInfer(); + return Task.FromResult(true); + } + return Task.FromResult(false); + } + } +} diff --git a/LLama.Web/Services/IModelSessionService.cs b/LLama.Web/Services/IModelSessionService.cs index 0642c9a3..4ee0d483 100644 --- a/LLama.Web/Services/IModelSessionService.cs +++ b/LLama.Web/Services/IModelSessionService.cs @@ -6,10 +6,10 @@ namespace LLama.Web.Services { public interface IModelSessionService { - Task GetAsync(string connectionId); - Task CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption); - Task RemoveAsync(string connectionId); - Task CancelAsync(string connectionId); + Task GetAsync(string sessionId); + Task> CreateAsync(LLamaExecutorType executorType, string sessionId, string modelName, string promptName, string parameterName); + Task RemoveAsync(string sessionId); + Task CancelAsync(string sessionId); } diff --git a/LLama.Web/Services/ModelSessionService.cs b/LLama.Web/Services/ModelSessionService.cs deleted file mode 100644 index 51b47f6e..00000000 --- a/LLama.Web/Services/ModelSessionService.cs +++ /dev/null @@ -1,58 +0,0 @@ -using LLama.Abstractions; -using LLama.Web.Common; -using LLama.Web.Models; -using System.Collections.Concurrent; - -namespace LLama.Web.Services -{ - public class ModelSessionService : IModelSessionService - { - private readonly ILogger _logger; - private readonly ConcurrentDictionary _modelSessions; - - public ModelSessionService(ILogger logger) - { - 
_logger = logger; - _modelSessions = new ConcurrentDictionary(); - } - - public Task GetAsync(string connectionId) - { - _modelSessions.TryGetValue(connectionId, out var modelSession); - return Task.FromResult(modelSession); - } - - public Task CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption) - { - //TODO: Max instance etc - var modelSession = new ModelSession(executor, modelOption, promptOption, parameterOption); - if (!_modelSessions.TryAdd(connectionId, modelSession)) - { - _logger.Log(LogLevel.Error, "[CreateAsync] - Failed to create model session, Connection: {0}", connectionId); - return Task.FromResult(default); - } - return Task.FromResult(modelSession); - } - - public Task RemoveAsync(string connectionId) - { - if (_modelSessions.TryRemove(connectionId, out var modelSession)) - { - _logger.Log(LogLevel.Information, "[RemoveAsync] - Removed model session, Connection: {0}", connectionId); - modelSession.Dispose(); - } - return Task.CompletedTask; - } - - public Task CancelAsync(string connectionId) - { - if (_modelSessions.TryGetValue(connectionId, out var modelSession)) - { - _logger.Log(LogLevel.Information, "[CancelAsync] - Canceled model session, Connection: {0}", connectionId); - modelSession.CancelInfer(); - } - return Task.CompletedTask; - } - - } -} diff --git a/LLama.Web/appsettings.json b/LLama.Web/appsettings.json index 9070a173..9f340a9c 100644 --- a/LLama.Web/appsettings.json +++ b/LLama.Web/appsettings.json @@ -10,9 +10,9 @@ "Models": [ { "Name": "WizardLM-7B", + "MaxInstances": 2, "ModelPath": "D:\\Repositories\\AI\\Models\\wizardLM-7B.ggmlv3.q4_0.bin", - "ContextSize": 2048, - "MaxInstances": 4 + "ContextSize": 2048 } ], "Parameters": [ @@ -22,6 +22,10 @@ } ], "Prompts": [ + { + "Name": "None", + "Prompt": "" + }, { "Name": "Alpaca", "Path": "D:\\Repositories\\AI\\Prompts\\alpaca.txt", diff --git a/LLama.Web/wwwroot/js/sessionConnectionChat.js 
b/LLama.Web/wwwroot/js/sessionConnectionChat.js new file mode 100644 index 00000000..472b5971 --- /dev/null +++ b/LLama.Web/wwwroot/js/sessionConnectionChat.js @@ -0,0 +1,176 @@ +const createConnectionSessionChat = (LLamaExecutorType) => { + const outputErrorTemplate = $("#outputErrorTemplate").html(); + const outputInfoTemplate = $("#outputInfoTemplate").html(); + const outputUserTemplate = $("#outputUserTemplate").html(); + const outputBotTemplate = $("#outputBotTemplate").html(); + const sessionDetailsTemplate = $("#sessionDetailsTemplate").html(); + + let connectionId; + const connection = new signalR.HubConnectionBuilder().withUrl("/SessionConnectionHub").build(); + + const scrollContainer = $("#scroll-container"); + const outputContainer = $("#output-container"); + const chatInput = $("#input"); + + + const onStatus = (connection, status) => { + connectionId = connection; + if (status == Enums.SessionConnectionStatus.Connected) { + $("#socket").text("Connected").addClass("text-success"); + } + else if (status == Enums.SessionConnectionStatus.Loaded) { + enableControls(); + $("#session-details").html(Mustache.render(sessionDetailsTemplate, { model: getSelectedModel(), prompt: getSelectedPrompt(), parameter: getSelectedParameter() })); + onInfo(`New model session successfully started`) + } + } + + const onError = (error) => { + enableControls(); + outputContainer.append(Mustache.render(outputErrorTemplate, { text: error, date: getDateTime() })); + } + + const onInfo = (message) => { + outputContainer.append(Mustache.render(outputInfoTemplate, { text: message, date: getDateTime() })); + } + + let responseContent; + let responseContainer; + let responseFirstFragment; + + const onResponse = (response) => { + if (!response) + return; + + if (response.isFirst) { + outputContainer.append(Mustache.render(outputBotTemplate, response)); + responseContainer = $(`#${response.id}`); + responseContent = responseContainer.find(".content"); + responseFirstFragment = true; + 
scrollToBottom(true); + return; + } + + if (response.isLast) { + enableControls(); + responseContainer.find(".signature").append(response.content); + scrollToBottom(); + } + else { + if (responseFirstFragment) { + responseContent.empty(); + responseFirstFragment = false; + responseContainer.find(".date").append(getDateTime()); + } + responseContent.append(response.content); + scrollToBottom(); + } + } + + + const sendPrompt = async () => { + const text = chatInput.val(); + if (text) { + disableControls(); + outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() })); + await connection.invoke('SendPrompt', text); + chatInput.val(null); + scrollToBottom(true); + } + } + + const cancelPrompt = async () => { + await ajaxPostJsonAsync('?handler=Cancel', { connectionId: connectionId }); + } + + const loadModel = async () => { + const modelName = getSelectedModel(); + const promptName = getSelectedPrompt(); + const parameterName = getSelectedParameter(); + if (!modelName || !promptName || !parameterName) { + onError("Please select a valid Model, Parameter and Prompt"); + return; + } + + disableControls(); + await connection.invoke('LoadModel', LLamaExecutorType, modelName, promptName, parameterName); + } + + + const enableControls = () => { + $(".input-control").removeAttr("disabled"); + } + + + const disableControls = () => { + $(".input-control").attr("disabled", "disabled"); + } + + const clearOutput = () => { + outputContainer.empty(); + } + + const updatePrompt = () => { + const customPrompt = $("#PromptText"); + const selection = $("option:selected", "#Prompt"); + const selectedValue = selection.data("prompt"); + customPrompt.text(selectedValue); + } + + + const getSelectedModel = () => { + return $("option:selected", "#Model").val(); + } + + + const getSelectedParameter = () => { + return $("option:selected", "#Parameter").val(); + } + + + const getSelectedPrompt = () => { + return $("option:selected", "#Prompt").val(); + } + + 
+ const getDateTime = () => { + const dateTime = new Date(); + return dateTime.toLocaleString(); + } + + + const scrollToBottom = (force) => { + const scrollTop = scrollContainer.scrollTop(); + const scrollHeight = scrollContainer[0].scrollHeight; + if (force) { + scrollContainer.scrollTop(scrollContainer[0].scrollHeight); + return; + } + if (scrollTop + 70 >= scrollHeight - scrollContainer.innerHeight()) { + scrollContainer.scrollTop(scrollContainer[0].scrollHeight) + } + } + + + + // Map UI functions + $("#load").on("click", loadModel); + $("#send").on("click", sendPrompt); + $("#clear").on("click", clearOutput); + $("#cancel").on("click", cancelPrompt); + $("#Prompt").on("change", updatePrompt); + chatInput.on('keydown', function (event) { + if (event.key === 'Enter' && !event.shiftKey) { + event.preventDefault(); + sendPrompt(); + } + }); + + + + // Map signalr functions + connection.on("OnStatus", onStatus); + connection.on("OnError", onError); + connection.on("OnResponse", onResponse); + connection.start(); +} \ No newline at end of file diff --git a/LLama.Web/wwwroot/js/site.js b/LLama.Web/wwwroot/js/site.js index 1fc916eb..2f679669 100644 --- a/LLama.Web/wwwroot/js/site.js +++ b/LLama.Web/wwwroot/js/site.js @@ -40,7 +40,11 @@ const Enums = { Loaded: 4, Connected: 10 }), - + LLamaExecutorType: Object.freeze({ + Interactive: 0, + Instruct: 1, + Stateless: 2 + }), GetName: (enumType, enumKey) => { return Object.keys(enumType)[enumKey] }, diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index dfa70edd..fe71a707 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -123,10 +123,12 @@ namespace LLama } /// - public async IAsyncEnumerable InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken token = default) + public async IAsyncEnumerable InferAsync(string text, InferenceParams? 
inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - yield return ""; - throw new NotImplementedException(); + foreach (var result in Infer(text, inferenceParams, cancellationToken)) + { + yield return result; + } } } }