Browse Source

Move session management to service, use ILLamaExecutor in session to make it more versatile, fix scroll bug

tags/v0.4.2-preview
sa_ddam213 2 years ago
parent
commit
a139423581
5 changed files with 27 additions and 20 deletions
  1. +6
    -5
      LLama.Web/Hubs/InteractiveHub.cs
  2. +7
    -8
      LLama.Web/Models/ModelSession.cs
  3. +7
    -2
      LLama.Web/Pages/Interactive.cshtml
  4. +3
    -2
      LLama.Web/Services/IModelSessionService.cs
  5. +4
    -3
      LLama.Web/Services/ModelSessionService.cs

+ 6
- 5
LLama.Web/Hubs/InteractiveHub.cs View File

@@ -45,7 +45,8 @@ namespace LLama.Web.Hubs
var modelOption = _options.Models.First(x => x.Name == modelName);
var promptOption = _options.Prompts.First(x => x.Name == promptName);
var parameterOption = _options.Parameters.First(x => x.Name == parameterName);
var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, modelOption, promptOption, parameterOption);
var interactiveExecutor = new InteractiveExecutor(new LLamaModel(modelOption));
var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, interactiveExecutor, modelOption, promptOption, parameterOption);
if (modelSession is null)
{
_logger.Log(LogLevel.Error, "[OnLoadModel] - Failed to add new model session, Connection: {0}", Context.ConnectionId);
@@ -72,15 +73,15 @@ namespace LLama.Web.Hubs
}

// Create unique response id
modelSession.ResponseId = Guid.NewGuid().ToString();
var responseId = Guid.NewGuid().ToString();

// Send begin of response
await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, isFirst: true));
await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true));

// Send content of response
await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted)))
{
await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, fragment));
await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment));
}

// Send end of response
@@ -88,7 +89,7 @@ namespace LLama.Web.Hubs
var signature = modelSession.IsInferCanceled()
? $"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds"
: $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds";
await Clients.Caller.OnResponse(new ResponseFragment(modelSession.ResponseId, signature, isLast: true));
await Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true));
_logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled());
}


+ 7
- 8
LLama.Web/Models/ModelSession.cs View File

@@ -9,24 +9,23 @@ namespace LLama.Web.Models
private PromptOptions _promptOptions;
private ParameterOptions _inferenceOptions;
private ITextStreamTransform _outputTransform;
private InteractiveExecutor _interactiveExecutor;
private ILLamaExecutor _executor;
private CancellationTokenSource _cancellationTokenSource;

public ModelSession(string connectionId, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions)
public ModelSession(string connectionId, ILLamaExecutor executor, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions)
{
ConnectionId = connectionId;
_executor = executor;
_modelOptions = modelOptions;
_promptOptions = promptOptions;
_inferenceOptions = parameterOptions;
_interactiveExecutor = new InteractiveExecutor(new LLamaModel(_modelOptions));
_inferenceOptions.AntiPrompts = _promptOptions.AntiPrompt?.Concat(_inferenceOptions.AntiPrompts ?? Enumerable.Empty<string>()).Distinct() ?? _inferenceOptions.AntiPrompts;
if (_promptOptions.OutputFilter?.Count > 0)
_outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5);
}

public string ConnectionId { get; }
public string ResponseId { get; set; }

public IAsyncEnumerable<string> InferAsync(string message, CancellationTokenSource cancellationTokenSource)
{
@@ -38,9 +37,9 @@ namespace LLama.Web.Models
}

if (_outputTransform is not null)
return _outputTransform.TransformAsync(_interactiveExecutor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token));
return _outputTransform.TransformAsync(_executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token));

return _interactiveExecutor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token);
return _executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token);
}


@@ -58,8 +57,8 @@ namespace LLama.Web.Models
{
_inferenceOptions = null;
_outputTransform = null;
_interactiveExecutor.Model?.Dispose();
_interactiveExecutor = null;
_executor.Model?.Dispose();
_executor = null;
}
}
}

+ 7
- 2
LLama.Web/Pages/Interactive.cshtml View File

@@ -205,7 +205,7 @@
responseContainer = $(`#${response.id}`);
responseContent = responseContainer.find(".content");
responseFirstFragment = true;
scrollToBottom();
scrollToBottom(true);
return;
}

@@ -233,6 +233,7 @@
outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() }));
connection.invoke('SendPrompt', text);
$("#input").val(null);
scrollToBottom(true);
}
}

@@ -309,9 +310,13 @@
}


const scrollToBottom = () => {
const scrollToBottom = (force) => {
const scrollTop = scrollContainer.scrollTop();
const scrollHeight = scrollContainer[0].scrollHeight;
if(force){
scrollContainer.scrollTop(scrollContainer[0].scrollHeight);
return;
}
if (scrollTop + 70 >= scrollHeight - scrollContainer.innerHeight()) {
scrollContainer.scrollTop(scrollContainer[0].scrollHeight)
}


+ 3
- 2
LLama.Web/Services/IModelSessionService.cs View File

@@ -1,11 +1,12 @@
using LLama.Web.Models;
using LLama.Abstractions;
using LLama.Web.Models;

namespace LLama.Web.Services
{
public interface IModelSessionService
{
Task<ModelSession> GetAsync(string connectionId);
Task<ModelSession> CreateAsync(string connectionId, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption);
Task<ModelSession> CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption);
Task RemoveAsync(string connectionId);
Task CancelAsync(string connectionId);
}


+ 4
- 3
LLama.Web/Services/ModelSessionService.cs View File

@@ -1,4 +1,5 @@
using LLama.Web.Models;
using LLama.Abstractions;
using LLama.Web.Models;
using System.Collections.Concurrent;

namespace LLama.Web.Services
@@ -20,10 +21,10 @@ namespace LLama.Web.Services
return Task.FromResult(modelSession);
}

public Task<ModelSession> CreateAsync(string connectionId, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption)
public Task<ModelSession> CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption)
{
//TODO: Max instance etc
var modelSession = new ModelSession(connectionId, modelOption, promptOption, parameterOption);
var modelSession = new ModelSession( connectionId, executor, modelOption, promptOption, parameterOption);
if (!_modelSessions.TryAdd(connectionId, modelSession))
{
_logger.Log(LogLevel.Error, "[CreateAsync] - Failed to create model session, Connection: {0}", connectionId);


Loading…
Cancel
Save