|
- using LLama.Abstractions;
- using LLama.Web.Common;
- using LLama.Web.Models;
- using Microsoft.Extensions.Options;
- using System.Collections.Concurrent;
-
- namespace LLama.Web.Services
- {
/// <summary>
/// Example service that manages one model session per websocket connection for the lifetime of that connection.
/// Every connection loads its own weights and creates its own executor, so multiple browser tabs can be used to compare prompts side by side.
/// </summary>
public class ConnectionSessionService : IModelSessionService
{
    private readonly LLamaOptions _options;
    private readonly ILogger<ConnectionSessionService> _logger;

    // Keyed by connection id. Each entry owns the native resources (weights, contexts) backing its session,
    // so they live exactly as long as the session and are released in RemoveAsync.
    private readonly ConcurrentDictionary<string, SessionEntry> _modelSessions;

    public ConnectionSessionService(ILogger<ConnectionSessionService> logger, IOptions<LLamaOptions> options)
    {
        _logger = logger;
        _options = options.Value;
        _modelSessions = new ConcurrentDictionary<string, SessionEntry>();
    }

    /// <summary>
    /// Gets the session belonging to a connection.
    /// </summary>
    /// <param name="connectionId">The websocket connection id.</param>
    /// <returns>The session, or <c>null</c> if the connection has no session.</returns>
    public Task<ModelSession> GetAsync(string connectionId)
    {
        _modelSessions.TryGetValue(connectionId, out var entry);
        return Task.FromResult(entry?.Session);
    }

    /// <summary>
    /// Creates a new session for a connection from the named model, prompt and parameter options.
    /// </summary>
    /// <param name="executorType">The kind of executor to create for the session.</param>
    /// <param name="connectionId">The websocket connection id the session belongs to.</param>
    /// <param name="modelName">Name of an entry in <c>LLamaOptions.Models</c>.</param>
    /// <param name="promptName">Name of an entry in <c>LLamaOptions.Prompts</c>.</param>
    /// <param name="parameterName">Name of an entry in <c>LLamaOptions.Parameters</c>.</param>
    /// <returns>
    /// The new session, or an error result when an option is not found, the model's instance limit is reached,
    /// the executor type is not supported, or the connection already has a session.
    /// </returns>
    public Task<IServiceResult<ModelSession>> CreateAsync(LLamaExecutorType executorType, string connectionId, string modelName, string promptName, string parameterName)
    {
        var modelOption = _options.Models.FirstOrDefault(x => x.Name == modelName);
        if (modelOption is null)
            return Task.FromResult(ServiceResult.FromError<ModelSession>($"Model option '{modelName}' not found"));

        var promptOption = _options.Prompts.FirstOrDefault(x => x.Name == promptName);
        if (promptOption is null)
            return Task.FromResult(ServiceResult.FromError<ModelSession>($"Prompt option '{promptName}' not found"));

        var parameterOption = _options.Parameters.FirstOrDefault(x => x.Name == parameterName);
        if (parameterOption is null)
            return Task.FromResult(ServiceResult.FromError<ModelSession>($"Parameter option '{parameterName}' not found"));

        // Fail fast, before the expensive weight load, if this connection already has a session.
        if (_modelSessions.ContainsKey(connectionId))
            return Task.FromResult(ServiceResult.FromError<ModelSession>("Failed to create model session"));

        // Max instances. NOTE: this check and the TryAdd below are not atomic, so concurrent
        // creations for the same model can briefly exceed MaxInstances.
        var currentInstances = _modelSessions.Count(x => x.Value.Session.ModelName == modelOption.Name);
        if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances)
            return Task.FromResult(ServiceResult.FromError<ModelSession>("Maximum model instances reached"));

        // Every native resource created below is recorded here. The finally block disposes them
        // unless ownership has been handed to a stored session entry.
        var resources = new List<IDisposable>();
        var ownershipTransferred = false;

        // Records a context so it is disposed together with the session.
        LLamaContext Track(LLamaContext context)
        {
            resources.Add(context);
            return context;
        }

        try
        {
            // Load weights. They must not be disposed at the end of this method: StatelessExecutor is
            // handed the weights themselves, so they have to stay alive as long as the session does.
            // todo: it would be better to have a central service which loads weights and shares them between all contexts that need them!
            var weights = LLamaWeights.LoadFromFile(modelOption);
            resources.Add(weights);

            // Create executor
            ILLamaExecutor executor = executorType switch
            {
                LLamaExecutorType.Interactive => new InteractiveExecutor(Track(new LLamaContext(weights, modelOption))),
                LLamaExecutorType.Instruct => new InstructExecutor(Track(new LLamaContext(weights, modelOption))),
                LLamaExecutorType.Stateless => new StatelessExecutor(weights, modelOption),
                _ => null
            };
            if (executor is null)
                return Task.FromResult(ServiceResult.FromError<ModelSession>($"Executor type '{executorType}' is not supported"));

            // Create session
            var modelSession = new ModelSession(executor, modelOption, promptOption, parameterOption);
            if (!_modelSessions.TryAdd(connectionId, new SessionEntry(modelSession, resources)))
            {
                // Another session was added for this connection concurrently; discard ours.
                modelSession.Dispose();
                return Task.FromResult(ServiceResult.FromError<ModelSession>("Failed to create model session"));
            }

            ownershipTransferred = true;
            _logger.LogInformation("Created {ExecutorType} session for connection {ConnectionId} using model {ModelName}", executorType, connectionId, modelOption.Name);
            return Task.FromResult(ServiceResult.FromValue(modelSession));
        }
        finally
        {
            if (!ownershipTransferred)
                DisposeResources(resources);
        }
    }

    /// <summary>
    /// Cancels any running inference for a connection, then disposes its session and the weights/contexts it owns.
    /// </summary>
    /// <param name="connectionId">The websocket connection id.</param>
    /// <returns><c>true</c> if a session was removed; <c>false</c> if the connection had none.</returns>
    public Task<bool> RemoveAsync(string connectionId)
    {
        if (!_modelSessions.TryRemove(connectionId, out var entry))
            return Task.FromResult(false);

        // NOTE(review): CancelInfer only requests cancellation; an inference still running when its
        // context is disposed may fail rather than stop cleanly - confirm against ModelSession/executor behaviour.
        entry.Session.CancelInfer();
        entry.Session.Dispose();
        DisposeResources(entry.Resources);
        _logger.LogInformation("Removed session for connection {ConnectionId}", connectionId);
        return Task.FromResult(true);
    }

    /// <summary>
    /// Requests cancellation of any running inference for a connection without removing its session.
    /// </summary>
    /// <param name="connectionId">The websocket connection id.</param>
    /// <returns><c>true</c> if the connection has a session; otherwise <c>false</c>.</returns>
    public Task<bool> CancelAsync(string connectionId)
    {
        if (_modelSessions.TryGetValue(connectionId, out var entry))
        {
            entry.Session.CancelInfer();
            return Task.FromResult(true);
        }
        return Task.FromResult(false);
    }

    /// <summary>
    /// Disposes resources in reverse creation order, so contexts are released before the weights they were created from.
    /// </summary>
    private static void DisposeResources(IReadOnlyList<IDisposable> resources)
    {
        for (var i = resources.Count - 1; i >= 0; i--)
            resources[i].Dispose();
    }

    /// <summary>
    /// A connection's session together with the native resources (weights, contexts) it owns.
    /// </summary>
    private sealed class SessionEntry
    {
        public SessionEntry(ModelSession session, IReadOnlyList<IDisposable> resources)
        {
            Session = session;
            Resources = resources;
        }

        public ModelSession Session { get; }

        public IReadOnlyList<IDisposable> Resources { get; }
    }
}
- }
|