From a139423581e3fcd5baa0cb577c9abd8a11e7cc0b Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Wed, 19 Jul 2023 08:55:56 +1200 Subject: [PATCH] Move session management to service, Use ILLamaExecutor in session to make more versatile, scroll bug --- LLama.Web/Hubs/InteractiveHub.cs | 11 ++++++----- LLama.Web/Models/ModelSession.cs | 15 +++++++-------- LLama.Web/Pages/Interactive.cshtml | 9 +++++++-- LLama.Web/Services/IModelSessionService.cs | 5 +++-- LLama.Web/Services/ModelSessionService.cs | 7 ++++--- 5 files changed, 27 insertions(+), 20 deletions(-) diff --git a/LLama.Web/Hubs/InteractiveHub.cs b/LLama.Web/Hubs/InteractiveHub.cs index 77879a4f..502904ad 100644 --- a/LLama.Web/Hubs/InteractiveHub.cs +++ b/LLama.Web/Hubs/InteractiveHub.cs @@ -45,7 +45,8 @@ namespace LLama.Web.Hubs var modelOption = _options.Models.First(x => x.Name == modelName); var promptOption = _options.Prompts.First(x => x.Name == promptName); var parameterOption = _options.Parameters.First(x => x.Name == parameterName); - var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, modelOption, promptOption, parameterOption); + var interactiveExecutor = new InteractiveExecutor(new LLamaModel(modelOption)); + var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, interactiveExecutor, modelOption, promptOption, parameterOption); if (modelSession is null) { _logger.Log(LogLevel.Error, "[OnLoadModel] - Failed to add new model session, Connection: {0}", Context.ConnectionId); @@ -72,15 +73,15 @@ namespace LLama.Web.Hubs } // Create unique response id - modelSession.ResponseId = Guid.NewGuid().ToString(); + var responseId = Guid.NewGuid().ToString(); // Send begin of response - await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, isFirst: true)); + await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true)); // Send content of response await foreach (var fragment in modelSession.InferAsync(prompt, 
CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted))) { - await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, fragment)); + await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment)); } // Send end of response @@ -88,7 +89,7 @@ namespace LLama.Web.Hubs var signature = modelSession.IsInferCanceled() ? $"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds" : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds"; - await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, signature, isLast: true)); + await Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true)); _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled()); } diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index d5dfbd07..00c32f50 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -9,24 +9,23 @@ namespace LLama.Web.Models private PromptOptions _promptOptions; private ParameterOptions _inferenceOptions; private ITextStreamTransform _outputTransform; - private InteractiveExecutor _interactiveExecutor; + private ILLamaExecutor _executor; private CancellationTokenSource _cancellationTokenSource; - public ModelSession(string connectionId, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions) + public ModelSession(string connectionId, ILLamaExecutor executor, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions) { ConnectionId = connectionId; + _executor = executor; _modelOptions = modelOptions; _promptOptions = promptOptions; _inferenceOptions = parameterOptions; - _interactiveExecutor = new InteractiveExecutor(new LLamaModel(_modelOptions)); _inferenceOptions.AntiPrompts = 
_promptOptions.AntiPrompt?.Concat(_inferenceOptions.AntiPrompts ?? Enumerable.Empty<string>()).Distinct() ?? _inferenceOptions.AntiPrompts; if (_promptOptions.OutputFilter?.Count > 0) _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5); } public string ConnectionId { get; } - public string ResponseId { get; set; } public IAsyncEnumerable<string> InferAsync(string message, CancellationTokenSource cancellationTokenSource) { @@ -38,9 +37,9 @@ namespace LLama.Web.Models } if (_outputTransform is not null) - return _outputTransform.TransformAsync(_interactiveExecutor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token)); + return _outputTransform.TransformAsync(_executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token)); - return _interactiveExecutor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token); + return _executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token); } @@ -58,8 +57,8 @@ namespace LLama.Web.Models { _inferenceOptions = null; _outputTransform = null; - _interactiveExecutor.Model?.Dispose(); - _interactiveExecutor = null; + _executor.Model?.Dispose(); + _executor = null; } } } diff --git a/LLama.Web/Pages/Interactive.cshtml b/LLama.Web/Pages/Interactive.cshtml index 3b1df234..5f39753c 100644 --- a/LLama.Web/Pages/Interactive.cshtml +++ b/LLama.Web/Pages/Interactive.cshtml @@ -205,7 +205,7 @@ responseContainer = $(`#${response.id}`); responseContent = responseContainer.find(".content"); responseFirstFragment = true; - scrollToBottom(); + scrollToBottom(true); return; } @@ -233,6 +233,7 @@ outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() })); connection.invoke('SendPrompt', text); $("#input").val(null); + scrollToBottom(true); } } @@ -309,9 +310,13 @@ } - const scrollToBottom = () => { + const scrollToBottom = (force) => { const scrollTop = scrollContainer.scrollTop(); const 
scrollHeight = scrollContainer[0].scrollHeight; + if(force){ + scrollContainer.scrollTop(scrollContainer[0].scrollHeight); + return; + } if (scrollTop + 70 >= scrollHeight - scrollContainer.innerHeight()) { scrollContainer.scrollTop(scrollContainer[0].scrollHeight) } diff --git a/LLama.Web/Services/IModelSessionService.cs b/LLama.Web/Services/IModelSessionService.cs index e62c9c06..33c797e6 100644 --- a/LLama.Web/Services/IModelSessionService.cs +++ b/LLama.Web/Services/IModelSessionService.cs @@ -1,11 +1,12 @@ -using LLama.Web.Models; +using LLama.Abstractions; +using LLama.Web.Models; namespace LLama.Web.Services { public interface IModelSessionService { Task GetAsync(string connectionId); - Task CreateAsync(string connectionId, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption); + Task CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption); Task RemoveAsync(string connectionId); Task CancelAsync(string connectionId); } diff --git a/LLama.Web/Services/ModelSessionService.cs b/LLama.Web/Services/ModelSessionService.cs index 17720551..5de0f316 100644 --- a/LLama.Web/Services/ModelSessionService.cs +++ b/LLama.Web/Services/ModelSessionService.cs @@ -1,4 +1,5 @@ -using LLama.Web.Models; +using LLama.Abstractions; +using LLama.Web.Models; using System.Collections.Concurrent; namespace LLama.Web.Services @@ -20,10 +21,10 @@ namespace LLama.Web.Services return Task.FromResult(modelSession); } - public Task CreateAsync(string connectionId, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption) + public Task CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption) { //TODO: Max instance etc - var modelSession = new ModelSession(connectionId, modelOption, promptOption, parameterOption); + var modelSession = new ModelSession( 
connectionId, executor, modelOption, promptOption, parameterOption); if (!_modelSessions.TryAdd(connectionId, modelSession)) { _logger.Log(LogLevel.Error, "[CreateAsync] - Failed to create model session, Connection: {0}", connectionId);