You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ConnectionSessionService.cs 4.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. using LLama.Abstractions;
  2. using LLama.Web.Common;
  3. using LLama.Web.Models;
  4. using Microsoft.Extensions.Options;
  5. using System.Collections.Concurrent;
  namespace LLama.Web.Services
  {
      /// <summary>
      /// Example service for handling a model session for the lifetime of a websocket connection.
      /// Each websocket connection creates its own session and context, so multiple tabs can be
      /// used side by side to compare prompts etc.
      /// </summary>
      public class ConnectionSessionService : IModelSessionService
      {
          private readonly LLamaOptions _options;
          private readonly ILogger<ConnectionSessionService> _logger;

          // Live sessions keyed by connection id.
          private readonly ConcurrentDictionary<string, ModelSession> _modelSessions;

          /// <summary>
          /// Creates the service with an empty session table.
          /// </summary>
          /// <param name="logger">Logger for this service.</param>
          /// <param name="options">Configured model, prompt and parameter options.</param>
          public ConnectionSessionService(ILogger<ConnectionSessionService> logger, IOptions<LLamaOptions> options)
          {
              _logger = logger;
              _options = options.Value;
              _modelSessions = new ConcurrentDictionary<string, ModelSession>();
          }

          /// <summary>
          /// Looks up the session registered for a connection.
          /// </summary>
          /// <param name="connectionId">The connection id the session was created with.</param>
          /// <returns>A completed task holding the session, or <c>null</c> when none is registered.</returns>
          public Task<ModelSession> GetAsync(string connectionId)
          {
              _modelSessions.TryGetValue(connectionId, out var modelSession);
              return Task.FromResult(modelSession);
          }

          /// <summary>
          /// Creates a model session for a connection and registers it under <paramref name="connectionId"/>.
          /// </summary>
          /// <param name="executorType">Which kind of executor to build for the session.</param>
          /// <param name="connectionId">The connection id to register the session under.</param>
          /// <param name="modelName">Name of the configured model option to use.</param>
          /// <param name="promptName">Name of the configured prompt option to use.</param>
          /// <param name="parameterName">Name of the configured parameter option to use.</param>
          /// <returns>
          /// A completed task holding the new session on success, or an error result when an option
          /// name is unknown, the model's instance limit is reached, the executor type is not
          /// supported, or a session is already registered for the connection.
          /// </returns>
          public Task<IServiceResult<ModelSession>> CreateAsync(LLamaExecutorType executorType, string connectionId, string modelName, string promptName, string parameterName)
          {
              var modelOption = _options.Models.FirstOrDefault(x => x.Name == modelName);
              if (modelOption is null)
              {
                  return Failure($"Model option '{modelName}' not found");
              }

              var promptOption = _options.Prompts.FirstOrDefault(x => x.Name == promptName);
              if (promptOption is null)
              {
                  return Failure($"Prompt option '{promptName}' not found");
              }

              var parameterOption = _options.Parameters.FirstOrDefault(x => x.Name == parameterName);
              if (parameterOption is null)
              {
                  return Failure($"Parameter option '{parameterName}' not found");
              }

              // Max instances: a MaxInstances of -1 (or lower) disables the limit.
              var currentInstances = _modelSessions.Count(x => x.Value.ModelName == modelOption.Name);
              if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances)
              {
                  return Failure("Maximum model instances reached");
              }

              // Cheap pre-check so we do not load weights for a connection that already has a
              // session. TryAdd below remains the authoritative check if two calls race.
              if (_modelSessions.ContainsKey(connectionId))
              {
                  return Failure("Failed to create model session");
              }

              // Load weights
              // todo: it would be better to have a central service which loads weights and shares them between all contexts that need them!
              // NOTE(review): the weights are disposed when this method returns while the executor
              // built from them lives on in the session - confirm the executors/contexts keep the
              // underlying model alive (in particular StatelessExecutor, which receives the weights directly).
              using var weights = LLamaWeights.LoadFromFile(modelOption);

              // Create executor
              ILLamaExecutor executor = executorType switch
              {
                  LLamaExecutorType.Interactive => new InteractiveExecutor(new LLamaContext(weights, modelOption)), //todo: properly dispose of LLamaContext
                  LLamaExecutorType.Instruct => new InstructExecutor(new LLamaContext(weights, modelOption)), //todo: properly dispose of LLamaContext
                  LLamaExecutorType.Stateless => new StatelessExecutor(weights, modelOption),
                  _ => default
              };

              // An unknown executor type used to produce a session wrapping a null executor;
              // report it as an error like every other invalid argument.
              if (executor is null)
              {
                  return Failure($"Executor type '{executorType}' is not supported");
              }

              // Create session
              var modelSession = new ModelSession(executor, modelOption, promptOption, parameterOption);
              if (!_modelSessions.TryAdd(connectionId, modelSession))
              {
                  // The session was never published, so release it here instead of leaking it.
                  modelSession.Dispose();
                  return Failure("Failed to create model session");
              }

              return Task.FromResult(ServiceResult.FromValue(modelSession));
          }

          /// <summary>
          /// Removes the session registered for a connection, cancelling any running inference and disposing it.
          /// </summary>
          /// <param name="connectionId">The connection id whose session should be removed.</param>
          /// <returns>A completed task holding <c>true</c> when a session was removed, otherwise <c>false</c>.</returns>
          public Task<bool> RemoveAsync(string connectionId)
          {
              if (_modelSessions.TryRemove(connectionId, out var modelSession))
              {
                  modelSession.CancelInfer();
                  modelSession.Dispose();
                  return Task.FromResult(true);
              }

              return Task.FromResult(false);
          }

          /// <summary>
          /// Cancels the running inference of the session registered for a connection; the session stays registered.
          /// </summary>
          /// <param name="connectionId">The connection id whose inference should be cancelled.</param>
          /// <returns>A completed task holding <c>true</c> when a session was found, otherwise <c>false</c>.</returns>
          public Task<bool> CancelAsync(string connectionId)
          {
              if (_modelSessions.TryGetValue(connectionId, out var modelSession))
              {
                  modelSession.CancelInfer();
                  return Task.FromResult(true);
              }

              return Task.FromResult(false);
          }

          /// <summary>
          /// Wraps an error message in a completed task carrying a failed service result.
          /// </summary>
          private static Task<IServiceResult<ModelSession>> Failure(string message)
          {
              return Task.FromResult(ServiceResult.FromError<ModelSession>(message));
          }
      }
  }