| @@ -45,7 +45,8 @@ namespace LLama.Web.Hubs | |||
| var modelOption = _options.Models.First(x => x.Name == modelName); | |||
| var promptOption = _options.Prompts.First(x => x.Name == promptName); | |||
| var parameterOption = _options.Parameters.First(x => x.Name == parameterName); | |||
| var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, modelOption, promptOption, parameterOption); | |||
| var interactiveExecutor = new InteractiveExecutor(new LLamaModel(modelOption)); | |||
| var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, interactiveExecutor, modelOption, promptOption, parameterOption); | |||
| if (modelSession is null) | |||
| { | |||
| _logger.Log(LogLevel.Error, "[OnLoadModel] - Failed to add new model session, Connection: {0}", Context.ConnectionId); | |||
| @@ -72,15 +73,15 @@ namespace LLama.Web.Hubs | |||
| } | |||
| // Create unique response id | |||
| modelSession.ResponseId = Guid.NewGuid().ToString(); | |||
| var responseId = Guid.NewGuid().ToString(); | |||
| // Send begin of response | |||
| await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, isFirst: true)); | |||
| await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true)); | |||
| // Send content of response | |||
| await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted))) | |||
| { | |||
| await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, fragment)); | |||
| await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment)); | |||
| } | |||
| // Send end of response | |||
| @@ -88,7 +89,7 @@ namespace LLama.Web.Hubs | |||
| var signature = modelSession.IsInferCanceled() | |||
| ? $"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds" | |||
| : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds"; | |||
| await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, signature, isLast: true)); | |||
| await Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true)); | |||
| _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled()); | |||
| } | |||
| @@ -9,24 +9,23 @@ namespace LLama.Web.Models | |||
| private PromptOptions _promptOptions; | |||
| private ParameterOptions _inferenceOptions; | |||
| private ITextStreamTransform _outputTransform; | |||
| private InteractiveExecutor _interactiveExecutor; | |||
| private ILLamaExecutor _executor; | |||
| private CancellationTokenSource _cancellationTokenSource; | |||
| public ModelSession(string connectionId, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions) | |||
| public ModelSession(string connectionId, ILLamaExecutor executor, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions) | |||
| { | |||
| ConnectionId = connectionId; | |||
| _executor = executor; | |||
| _modelOptions = modelOptions; | |||
| _promptOptions = promptOptions; | |||
| _inferenceOptions = parameterOptions; | |||
| _interactiveExecutor = new InteractiveExecutor(new LLamaModel(_modelOptions)); | |||
| _inferenceOptions.AntiPrompts = _promptOptions.AntiPrompt?.Concat(_inferenceOptions.AntiPrompts ?? Enumerable.Empty<string>()).Distinct() ?? _inferenceOptions.AntiPrompts; | |||
| if (_promptOptions.OutputFilter?.Count > 0) | |||
| _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5); | |||
| } | |||
| public string ConnectionId { get; } | |||
| public string ResponseId { get; set; } | |||
| public IAsyncEnumerable<string> InferAsync(string message, CancellationTokenSource cancellationTokenSource) | |||
| { | |||
| @@ -38,9 +37,9 @@ namespace LLama.Web.Models | |||
| } | |||
| if (_outputTransform is not null) | |||
| return _outputTransform.TransformAsync(_interactiveExecutor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token)); | |||
| return _outputTransform.TransformAsync(_executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token)); | |||
| return _interactiveExecutor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token); | |||
| return _executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token); | |||
| } | |||
| @@ -58,8 +57,8 @@ namespace LLama.Web.Models | |||
| { | |||
| _inferenceOptions = null; | |||
| _outputTransform = null; | |||
| _interactiveExecutor.Model?.Dispose(); | |||
| _interactiveExecutor = null; | |||
| _executor.Model?.Dispose(); | |||
| _executor = null; | |||
| } | |||
| } | |||
| } | |||
| @@ -205,7 +205,7 @@ | |||
| responseContainer = $(`#${response.id}`); | |||
| responseContent = responseContainer.find(".content"); | |||
| responseFirstFragment = true; | |||
| scrollToBottom(); | |||
| scrollToBottom(true); | |||
| return; | |||
| } | |||
| @@ -233,6 +233,7 @@ | |||
| outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() })); | |||
| connection.invoke('SendPrompt', text); | |||
| $("#input").val(null); | |||
| scrollToBottom(true); | |||
| } | |||
| } | |||
| @@ -309,9 +310,13 @@ | |||
| } | |||
| const scrollToBottom = () => { | |||
| const scrollToBottom = (force) => { | |||
| const scrollTop = scrollContainer.scrollTop(); | |||
| const scrollHeight = scrollContainer[0].scrollHeight; | |||
| if(force){ | |||
| scrollContainer.scrollTop(scrollContainer[0].scrollHeight); | |||
| return; | |||
| } | |||
| if (scrollTop + 70 >= scrollHeight - scrollContainer.innerHeight()) { | |||
| scrollContainer.scrollTop(scrollContainer[0].scrollHeight) | |||
| } | |||
| @@ -1,11 +1,12 @@ | |||
| using LLama.Web.Models; | |||
| using LLama.Abstractions; | |||
| using LLama.Web.Models; | |||
| namespace LLama.Web.Services | |||
| { | |||
| public interface IModelSessionService | |||
| { | |||
| Task<ModelSession> GetAsync(string connectionId); | |||
| Task<ModelSession> CreateAsync(string connectionId, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption); | |||
| Task<ModelSession> CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption); | |||
| Task RemoveAsync(string connectionId); | |||
| Task CancelAsync(string connectionId); | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using LLama.Web.Models; | |||
| using LLama.Abstractions; | |||
| using LLama.Web.Models; | |||
| using System.Collections.Concurrent; | |||
| namespace LLama.Web.Services | |||
| @@ -20,10 +21,10 @@ namespace LLama.Web.Services | |||
| return Task.FromResult(modelSession); | |||
| } | |||
| public Task<ModelSession> CreateAsync(string connectionId, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption) | |||
| public Task<ModelSession> CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption) | |||
| { | |||
| //TODO: Max instance etc | |||
| var modelSession = new ModelSession(connectionId, modelOption, promptOption, parameterOption); | |||
| var modelSession = new ModelSession( connectionId, executor, modelOption, promptOption, parameterOption); | |||
| if (!_modelSessions.TryAdd(connectionId, modelSession)) | |||
| { | |||
| _logger.Log(LogLevel.Error, "[CreateAsync] - Failed to create model session, Connection: {0}", connectionId); | |||