using LLama.Web.Async;
using LLama.Web.Common;
using LLama.Web.Models;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Runtime.CompilerServices;
namespace LLama.Web.Services
{
///
/// Example Service for handling a model session for a websockets connection lifetime
/// Each websocket connection will create its own unique session and context allowing you to use multiple tabs to compare prompts etc
///
/// <summary>
/// Example service for handling a model session for a websocket connection lifetime.
/// Each websocket connection creates its own unique session and context, allowing you to use multiple tabs to compare prompts etc.
/// </summary>
public class ModelSessionService : IModelSessionService
{
    // Guards against concurrent inference on the same session id.
    private readonly AsyncGuard<string> _sessionGuard;
    private readonly IModelService _modelService;
    // Active sessions keyed by session id (one per websocket connection).
    private readonly ConcurrentDictionary<string, ModelSession> _modelSessions;

    /// <summary>
    /// Initializes a new instance of the <see cref="ModelSessionService"/> class.
    /// </summary>
    /// <param name="modelService">The model service used to create and remove model contexts.</param>
    public ModelSessionService(IModelService modelService)
    {
        _modelService = modelService;
        _sessionGuard = new AsyncGuard<string>();
        _modelSessions = new ConcurrentDictionary<string, ModelSession>();
    }

    /// <summary>
    /// Gets the <see cref="ModelSession"/> with the specified id.
    /// </summary>
    /// <param name="sessionId">The session identifier.</param>
    /// <returns>The <see cref="ModelSession"/> if it exists, otherwise null.</returns>
    public Task<ModelSession> GetAsync(string sessionId)
    {
        return Task.FromResult(_modelSessions.TryGetValue(sessionId, out var session) ? session : null);
    }

    /// <summary>
    /// Gets all <see cref="ModelSession"/> instances.
    /// </summary>
    /// <returns>A collection of all current <see cref="ModelSession"/> instances.</returns>
    public Task<IEnumerable<ModelSession>> GetAllAsync()
    {
        return Task.FromResult<IEnumerable<ModelSession>>(_modelSessions.Values);
    }

    /// <summary>
    /// Creates a new <see cref="ModelSession"/> and runs its initial prompt.
    /// </summary>
    /// <param name="sessionId">The session identifier.</param>
    /// <param name="sessionConfig">The session configuration.</param>
    /// <param name="inferenceConfig">The default inference configuration, used for all inference where no inference configuration is supplied.</param>
    /// <param name="cancellationToken">The cancellation token.</param>
    /// <returns>The newly created <see cref="ModelSession"/>.</returns>
    /// <exception cref="Exception">
    /// Session with id {sessionId} already exists
    /// or
    /// Failed to create model session
    /// </exception>
    public async Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
    {
        if (_modelSessions.TryGetValue(sessionId, out _))
            throw new Exception($"Session with id {sessionId} already exists");

        // Create context
        var (model, context) = await _modelService.GetOrCreateModelAndContext(sessionConfig.Model, sessionId);

        // Create session
        var modelSession = new ModelSession(model, context, sessionId, sessionConfig, inferenceConfig);
        if (!_modelSessions.TryAdd(sessionId, modelSession))
            throw new Exception($"Failed to create model session");

        // Run initial Prompt
        await modelSession.InitializePrompt(inferenceConfig, cancellationToken);
        return modelSession;
    }

    /// <summary>
    /// Closes the session, cancelling any in-flight inference and releasing its model context.
    /// </summary>
    /// <param name="sessionId">The session identifier.</param>
    /// <returns>true if the session existed and its context was removed, otherwise false.</returns>
    public async Task<bool> CloseAsync(string sessionId)
    {
        if (_modelSessions.TryRemove(sessionId, out var modelSession))
        {
            modelSession.CancelInfer();
            return await _modelService.RemoveContext(modelSession.ModelName, sessionId);
        }
        return false;
    }

    /// <summary>
    /// Runs inference on the specified <see cref="ModelSession"/>, streaming tokens as they are produced.
    /// </summary>
    /// <param name="sessionId">The session identifier.</param>
    /// <param name="prompt">The prompt.</param>
    /// <param name="inferenceConfig">The inference configuration; if null the session default is used.</param>
    /// <param name="cancellationToken">The cancellation token.</param>
    /// <returns>A stream of <see cref="TokenModel"/> items, framed by Begin and End/Cancel tokens.</returns>
    /// <exception cref="Exception">Inference is already running for this session</exception>
    public async IAsyncEnumerable<TokenModel> InferAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Only one inference per session at a time; the guard is released in the finally block.
        if (!_sessionGuard.Guard(sessionId))
            throw new Exception($"Inference is already running for this session");

        try
        {
            if (!_modelSessions.TryGetValue(sessionId, out var modelSession))
                yield break;

            // Send begin of response
            var startTimestamp = Stopwatch.GetTimestamp();
            yield return new TokenModel(default, default, TokenType.Begin);

            // Send content of response
            await foreach (var token in modelSession.InferAsync(prompt, inferenceConfig, cancellationToken).ConfigureAwait(false))
            {
                yield return new TokenModel(default, token);
            }

            // Send end of response, flagging whether inference ran to completion or was cancelled
            var elapsedTime = GetElapsed(startTimestamp);
            var endTokenType = modelSession.IsInferCanceled() ? TokenType.Cancel : TokenType.End;
            var signature = endTokenType == TokenType.Cancel
                  ? $"Inference cancelled after {elapsedTime / 1000:F0} seconds"
                  : $"Inference completed in {elapsedTime / 1000:F0} seconds";
            yield return new TokenModel(default, signature, endTokenType);
        }
        finally
        {
            _sessionGuard.Release(sessionId);
        }
    }

    /// <summary>
    /// Runs inference on the specified <see cref="ModelSession"/>, streaming only the content text of each token.
    /// </summary>
    /// <param name="sessionId">The session identifier.</param>
    /// <param name="prompt">The prompt.</param>
    /// <param name="inferenceConfig">The inference configuration; if null the session default is used.</param>
    /// <param name="cancellationToken">The cancellation token.</param>
    /// <returns>Streaming async result of <see cref="string"/>.</returns>
    /// <exception cref="Exception">Inference is already running for this session</exception>
    public IAsyncEnumerable<string> InferTextAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
    {
        // Local iterator so the Begin/End framing tokens are filtered out lazily.
        async IAsyncEnumerable<string> InferTextInternal()
        {
            await foreach (var token in InferAsync(sessionId, prompt, inferenceConfig, cancellationToken).ConfigureAwait(false))
            {
                if (token.TokenType == TokenType.Content)
                    yield return token.Content;
            }
        }
        return InferTextInternal();
    }

    /// <summary>
    /// Runs inference on the specified <see cref="ModelSession"/> and buffers the whole response.
    /// </summary>
    /// <param name="sessionId">The session identifier.</param>
    /// <param name="prompt">The prompt.</param>
    /// <param name="inferenceConfig">The inference configuration; if null the session default is used.</param>
    /// <param name="cancellationToken">The cancellation token.</param>
    /// <returns>Completed inference result as string.</returns>
    /// <exception cref="Exception">Inference is already running for this session</exception>
    public async Task<string> InferTextCompleteAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
    {
        var inferResult = await InferAsync(sessionId, prompt, inferenceConfig, cancellationToken)
            .Where(x => x.TokenType == TokenType.Content)
            .Select(x => x.Content)
            .ToListAsync(cancellationToken: cancellationToken);

        return string.Concat(inferResult);
    }

    /// <summary>
    /// Cancels the current inference action for the session, if any.
    /// </summary>
    /// <param name="sessionId">The session identifier.</param>
    /// <returns>true if the session exists and cancellation was requested, otherwise false.</returns>
    public Task<bool> CancelAsync(string sessionId)
    {
        if (_modelSessions.TryGetValue(sessionId, out var modelSession))
        {
            modelSession.CancelInfer();
            return Task.FromResult(true);
        }
        return Task.FromResult(false);
    }

    /// <summary>
    /// Gets the elapsed time in milliseconds since the given <see cref="Stopwatch"/> timestamp.
    /// </summary>
    /// <param name="timestamp">A timestamp captured via <see cref="Stopwatch.GetTimestamp"/>.</param>
    /// <returns>Elapsed milliseconds, truncated to an int.</returns>
    private static int GetElapsed(long timestamp)
    {
        return (int)Stopwatch.GetElapsedTime(timestamp).TotalMilliseconds;
    }
}
}