using LLama.Web.Async;
using LLama.Web.Common;
using LLama.Web.Models;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Runtime.CompilerServices;

namespace LLama.Web.Services
{
    /// <summary>
    /// Example service for handling a model session for the lifetime of a websocket connection.
    /// Each websocket connection creates its own unique session and context, allowing you to use
    /// multiple tabs to compare prompts, etc.
    /// </summary>
    public class ModelSessionService : IModelSessionService
    {
        private readonly AsyncGuard<string> _sessionGuard;
        private readonly IModelService _modelService;
        private readonly ConcurrentDictionary<string, ModelSession> _modelSessions;

        /// <summary>
        /// Initializes a new instance of the <see cref="ModelSessionService"/> class.
        /// </summary>
        /// <param name="modelService">The model service used to create and remove model contexts.</param>
        public ModelSessionService(IModelService modelService)
        {
            _modelService = modelService;
            _sessionGuard = new AsyncGuard<string>();
            _modelSessions = new ConcurrentDictionary<string, ModelSession>();
        }

        /// <summary>
        /// Gets the <see cref="ModelSession"/> with the specified id.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <returns>The <see cref="ModelSession"/> if it exists, otherwise <c>null</c>.</returns>
        public Task<ModelSession> GetAsync(string sessionId)
        {
            return Task.FromResult(_modelSessions.TryGetValue(sessionId, out var session) ? session : null);
        }

        /// <summary>
        /// Gets all registered <see cref="ModelSession"/> instances.
        /// </summary>
        /// <returns>A collection of all <see cref="ModelSession"/> instances.</returns>
        public Task<IEnumerable<ModelSession>> GetAllAsync()
        {
            return Task.FromResult<IEnumerable<ModelSession>>(_modelSessions.Values);
        }

        /// <summary>
        /// Creates a new <see cref="ModelSession"/>, registers it under <paramref name="sessionId"/>
        /// and runs its initial prompt.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <param name="sessionConfig">The session configuration.</param>
        /// <param name="inferenceConfig">The default inference configuration, used for all inference where no inference configuration is supplied.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The newly created and initialized <see cref="ModelSession"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sessionConfig"/> is <c>null</c>.</exception>
        /// <exception cref="InvalidOperationException">
        /// A session with id <paramref name="sessionId"/> already exists,
        /// or the model session could not be registered.
        /// </exception>
        /// <remarks>
        /// If running the initial prompt fails (including cancellation), the session is unregistered and its
        /// context released before the exception is rethrown, so the same id can be used again.
        /// </remarks>
        public async Task<ModelSession> CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            ArgumentNullException.ThrowIfNull(sessionConfig);

            if (_modelSessions.ContainsKey(sessionId))
                throw new InvalidOperationException($"Session with id {sessionId} already exists");

            // Create context
            var (model, context) = await _modelService.GetOrCreateModelAndContext(sessionConfig.Model, sessionId);

            // Create session
            var modelSession = new ModelSession(model, context, sessionId, sessionConfig, inferenceConfig);
            if (!_modelSessions.TryAdd(sessionId, modelSession))
                throw new InvalidOperationException("Failed to create model session");

            // Run initial prompt
            try
            {
                await modelSession.InitializePrompt(inferenceConfig, cancellationToken);
            }
            catch
            {
                // Don't leave a half-initialized session registered; otherwise the id stays taken and every
                // retry fails with "already exists". Only remove the entry if it still belongs to this call.
                if (_modelSessions.TryRemove(new KeyValuePair<string, ModelSession>(sessionId, modelSession)))
                    await _modelService.RemoveContext(modelSession.ModelName, sessionId);

                throw;
            }

            return modelSession;
        }

        /// <summary>
        /// Closes the session: cancels any running inference and removes its context.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <returns>
        /// The result of removing the context if the session existed; <c>false</c> if no session was registered.
        /// </returns>
        public async Task<bool> CloseAsync(string sessionId)
        {
            if (_modelSessions.TryRemove(sessionId, out var modelSession))
            {
                modelSession.CancelInfer();
                return await _modelService.RemoveContext(modelSession.ModelName, sessionId);
            }
            return false;
        }

        /// <summary>
        /// Runs inference on the current <see cref="ModelSession"/>.
        /// Yields a <see cref="TokenType.Begin"/> token, then one content token per generated token, then a
        /// final <see cref="TokenType.End"/> or <see cref="TokenType.Cancel"/> token carrying a timing message.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <param name="prompt">The prompt.</param>
        /// <param name="inferenceConfig">The inference configuration; if <c>null</c> the session default is used.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>Streaming result of <see cref="TokenModel"/>; empty if the session does not exist.</returns>
        /// <exception cref="InvalidOperationException">Inference is already running for this session.</exception>
        public async IAsyncEnumerable<TokenModel> InferAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            // Only one inference per session at a time; the guard is released in the finally below.
            if (!_sessionGuard.Guard(sessionId))
                throw new InvalidOperationException("Inference is already running for this session");

            try
            {
                if (!_modelSessions.TryGetValue(sessionId, out var modelSession))
                    yield break;

                // Send begin of response
                var startTimestamp = Stopwatch.GetTimestamp();
                yield return new TokenModel(default, default, TokenType.Begin);

                // Send content of response
                await foreach (var token in modelSession.InferAsync(prompt, inferenceConfig, cancellationToken).ConfigureAwait(false))
                {
                    yield return new TokenModel(default, token);
                }

                // Send end of response. Divide by 1000.0 so F0 rounds to whole seconds instead of the
                // integer division silently truncating (e.g. 1999 ms -> "2", not "1").
                var elapsedSeconds = GetElapsed(startTimestamp) / 1000.0;
                var endTokenType = modelSession.IsInferCanceled() ? TokenType.Cancel : TokenType.End;
                var signature = endTokenType == TokenType.Cancel
                    ? $"Inference cancelled after {elapsedSeconds:F0} seconds"
                    : $"Inference completed in {elapsedSeconds:F0} seconds";
                yield return new TokenModel(default, signature, endTokenType);
            }
            finally
            {
                _sessionGuard.Release(sessionId);
            }
        }

        /// <summary>
        /// Runs inference on the current <see cref="ModelSession"/>, streaming only the content text.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <param name="prompt">The prompt.</param>
        /// <param name="inferenceConfig">The inference configuration; if <c>null</c> the session default is used.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>Streaming result of the content tokens as <see cref="string"/>.</returns>
        /// <exception cref="InvalidOperationException">Inference is already running for this session.</exception>
        public IAsyncEnumerable<string> InferTextAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            async IAsyncEnumerable<string> InferTextInternal()
            {
                await foreach (var token in InferAsync(sessionId, prompt, inferenceConfig, cancellationToken).ConfigureAwait(false))
                {
                    if (token.TokenType == TokenType.Content)
                        yield return token.Content;
                }
            }
            return InferTextInternal();
        }

        /// <summary>
        /// Runs inference on the current <see cref="ModelSession"/> and waits for it to complete.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <param name="prompt">The prompt.</param>
        /// <param name="inferenceConfig">The inference configuration; if <c>null</c> the session default is used.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The concatenated content tokens of the completed inference.</returns>
        /// <exception cref="InvalidOperationException">Inference is already running for this session.</exception>
        public async Task<string> InferTextCompleteAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            var inferResult = await InferAsync(sessionId, prompt, inferenceConfig, cancellationToken)
                .Where(x => x.TokenType == TokenType.Content)
                .Select(x => x.Content)
                .ToListAsync(cancellationToken: cancellationToken);

            return string.Concat(inferResult);
        }

        /// <summary>
        /// Cancels the current inference action.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <returns><c>true</c> if the session exists and cancellation was requested, otherwise <c>false</c>.</returns>
        public Task<bool> CancelAsync(string sessionId)
        {
            if (_modelSessions.TryGetValue(sessionId, out var modelSession))
            {
                modelSession.CancelInfer();
                return Task.FromResult(true);
            }
            return Task.FromResult(false);
        }

        /// <summary>
        /// Gets the elapsed time in milliseconds since <paramref name="timestamp"/>.
        /// </summary>
        /// <param name="timestamp">A timestamp obtained from <see cref="Stopwatch.GetTimestamp"/>.</param>
        /// <returns>The elapsed time in whole milliseconds (truncated).</returns>
        private static int GetElapsed(long timestamp)
        {
            return (int)Stopwatch.GetElapsedTime(timestamp).TotalMilliseconds;
        }
    }
}