You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ModelSessionService.cs 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. using LLama.Abstractions;
  2. using LLama.Web.Common;
  3. using LLama.Web.Models;
  4. using System.Collections.Concurrent;
  5. namespace LLama.Web.Services
  6. {
  /// <summary>
  /// In-memory registry of <see cref="ModelSession"/> instances keyed by connection id.
  /// Provides lookup, creation, removal and cancellation of sessions. Every method does
  /// its work synchronously and returns an already-completed task.
  /// </summary>
  public class ModelSessionService : IModelSessionService
  {
      private readonly ILogger<ModelSessionService> _logger;

      // Registered sessions, keyed by connection id.
      private readonly ConcurrentDictionary<string, ModelSession> _modelSessions;

      /// <summary>
      /// Creates a service with an empty session registry.
      /// </summary>
      /// <param name="logger">Logger used to record session lifecycle events.</param>
      public ModelSessionService(ILogger<ModelSessionService> logger)
      {
          _logger = logger;
          _modelSessions = new ConcurrentDictionary<string, ModelSession>();
      }

      /// <summary>
      /// Looks up the session registered for a connection.
      /// </summary>
      /// <param name="connectionId">Connection whose session is requested.</param>
      /// <returns>
      /// A completed task whose result is the registered session, or <c>null</c> when no
      /// session is registered for <paramref name="connectionId"/>.
      /// </returns>
      public Task<ModelSession> GetAsync(string connectionId)
      {
          _modelSessions.TryGetValue(connectionId, out var modelSession);
          return Task.FromResult(modelSession);
      }

      /// <summary>
      /// Builds a new <see cref="ModelSession"/> and registers it for a connection.
      /// </summary>
      /// <param name="connectionId">Connection the new session belongs to.</param>
      /// <param name="executor">Executor handed to the new session.</param>
      /// <param name="modelOption">Model options handed to the new session.</param>
      /// <param name="promptOption">Prompt options handed to the new session.</param>
      /// <param name="parameterOption">Inference parameter options handed to the new session.</param>
      /// <returns>
      /// A completed task whose result is the new session, or <c>default</c> (null) when a
      /// session is already registered for <paramref name="connectionId"/>.
      /// </returns>
      public Task<ModelSession> CreateAsync(string connectionId, ILLamaExecutor executor, ModelOptions modelOption, PromptOptions promptOption, ParameterOptions parameterOption)
      {
          //TODO: Max instance etc
          var modelSession = new ModelSession(executor, modelOption, promptOption, parameterOption);
          if (_modelSessions.TryAdd(connectionId, modelSession))
          {
              return Task.FromResult(modelSession);
          }

          _logger.Log(LogLevel.Error, "[CreateAsync] - Failed to create model session, Connection: {0}", connectionId);

          // The session was never registered, so RemoveAsync can never reach it, and the
          // caller only gets default back. Dispose it here, the same way RemoveAsync
          // disposes registered sessions, instead of leaking it.
          // NOTE(review): if ModelSession.Dispose also releases the executor passed in,
          // callers must not dispose that executor again after a failed create — confirm.
          modelSession.Dispose();
          return Task.FromResult<ModelSession>(default);
      }

      /// <summary>
      /// Unregisters and disposes the session for a connection, if one exists.
      /// </summary>
      /// <param name="connectionId">Connection whose session should be removed.</param>
      /// <returns>A completed task.</returns>
      public Task RemoveAsync(string connectionId)
      {
          if (_modelSessions.TryRemove(connectionId, out var modelSession))
          {
              _logger.Log(LogLevel.Information, "[RemoveAsync] - Removed model session, Connection: {0}", connectionId);
              modelSession.Dispose();
          }
          return Task.CompletedTask;
      }

      /// <summary>
      /// Requests cancellation of the in-flight inference of the session registered for a
      /// connection, if one exists. The session stays registered.
      /// </summary>
      /// <param name="connectionId">Connection whose session should be cancelled.</param>
      /// <returns>A completed task.</returns>
      public Task CancelAsync(string connectionId)
      {
          if (_modelSessions.TryGetValue(connectionId, out var modelSession))
          {
              modelSession.CancelInfer();
              // Log only once the cancel request has actually been issued.
              _logger.Log(LogLevel.Information, "[CancelAsync] - Canceled model session, Connection: {0}", connectionId);
          }
          return Task.CompletedTask;
      }
  }
  51. }