
ModelSessionService.cs

using LLama.Web.Async;
using LLama.Web.Common;
using LLama.Web.Models;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Runtime.CompilerServices;

namespace LLama.Web.Services
{
    /// <summary>
    /// Example service for handling a model session for the lifetime of a WebSocket connection.
    /// Each WebSocket connection creates its own unique session and context, allowing you to use multiple tabs to compare prompts etc.
    /// </summary>
    public class ModelSessionService : IModelSessionService
    {
        private readonly AsyncGuard<string> _sessionGuard;
        private readonly IModelService _modelService;
        private readonly ConcurrentDictionary<string, ModelSession> _modelSessions;

        /// <summary>
        /// Initializes a new instance of the <see cref="ModelSessionService"/> class.
        /// </summary>
        /// <param name="modelService">The model service.</param>
        public ModelSessionService(IModelService modelService)
        {
            _modelService = modelService;
            _sessionGuard = new AsyncGuard<string>();
            _modelSessions = new ConcurrentDictionary<string, ModelSession>();
        }

        /// <summary>
        /// Gets the ModelSession with the specified Id.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <returns>The ModelSession if it exists, otherwise null.</returns>
        public Task<ModelSession> GetAsync(string sessionId)
        {
            return Task.FromResult(_modelSessions.TryGetValue(sessionId, out var session) ? session : null);
        }

        /// <summary>
        /// Gets all ModelSessions.
        /// </summary>
        /// <returns>A collection of all ModelSession instances.</returns>
        public Task<IEnumerable<ModelSession>> GetAllAsync()
        {
            return Task.FromResult<IEnumerable<ModelSession>>(_modelSessions.Values);
        }

        /// <summary>
        /// Creates a new ModelSession.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <param name="sessionConfig">The session configuration.</param>
        /// <param name="inferenceConfig">The default inference configuration, used for all inference where no inference configuration is supplied.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
        /// <exception cref="System.Exception">
        /// Session with id {sessionId} already exists
        /// or
        /// Failed to create model session
        /// </exception>
        public async Task<ModelSession> CreateAsync(string sessionId, Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            if (_modelSessions.TryGetValue(sessionId, out _))
                throw new Exception($"Session with id {sessionId} already exists");

            // Create context
            var (model, context) = await _modelService.GetOrCreateModelAndContext(sessionConfig.Model, sessionId);

            // Create session
            var modelSession = new ModelSession(model, context, sessionId, sessionConfig, inferenceConfig);
            if (!_modelSessions.TryAdd(sessionId, modelSession))
                throw new Exception($"Failed to create model session");

            // Run initial prompt
            await modelSession.InitializePrompt(inferenceConfig, cancellationToken);
            return modelSession;
        }

        /// <summary>
        /// Closes the session.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <returns></returns>
        public async Task<bool> CloseAsync(string sessionId)
        {
            if (_modelSessions.TryRemove(sessionId, out var modelSession))
            {
                modelSession.CancelInfer();
                return await _modelService.RemoveContext(modelSession.ModelName, sessionId);
            }
            return false;
        }

        /// <summary>
        /// Runs inference on the current ModelSession.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <param name="prompt">The prompt.</param>
        /// <param name="inferenceConfig">The inference configuration; if null, the session default is used.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <exception cref="System.Exception">Inference is already running for this session</exception>
        public async IAsyncEnumerable<TokenModel> InferAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            if (!_sessionGuard.Guard(sessionId))
                throw new Exception($"Inference is already running for this session");

            try
            {
                if (!_modelSessions.TryGetValue(sessionId, out var modelSession))
                    yield break;

                // Send begin of response
                var stopwatch = Stopwatch.GetTimestamp();
                yield return new TokenModel(default, default, TokenType.Begin);

                // Send content of response
                await foreach (var token in modelSession.InferAsync(prompt, inferenceConfig, cancellationToken).ConfigureAwait(false))
                {
                    yield return new TokenModel(default, token);
                }

                // Send end of response
                var elapsedTime = GetElapsed(stopwatch);
                var endTokenType = modelSession.IsInferCanceled() ? TokenType.Cancel : TokenType.End;
                var signature = endTokenType == TokenType.Cancel
                    ? $"Inference cancelled after {elapsedTime / 1000:F0} seconds"
                    : $"Inference completed in {elapsedTime / 1000:F0} seconds";
                yield return new TokenModel(default, signature, endTokenType);
            }
            finally
            {
                _sessionGuard.Release(sessionId);
            }
        }

        /// <summary>
        /// Runs inference on the current ModelSession, streaming only the generated text.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <param name="prompt">The prompt.</param>
        /// <param name="inferenceConfig">The inference configuration; if null, the session default is used.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>Streaming async result of <see cref="System.String" /></returns>
        /// <exception cref="System.Exception">Inference is already running for this session</exception>
        public IAsyncEnumerable<string> InferTextAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            async IAsyncEnumerable<string> InferTextInternal()
            {
                await foreach (var token in InferAsync(sessionId, prompt, inferenceConfig, cancellationToken).ConfigureAwait(false))
                {
                    if (token.TokenType == TokenType.Content)
                        yield return token.Content;
                }
            }
            return InferTextInternal();
        }

        /// <summary>
        /// Runs inference on the current ModelSession and returns the completed result.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <param name="prompt">The prompt.</param>
        /// <param name="inferenceConfig">The inference configuration; if null, the session default is used.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>Completed inference result as string</returns>
        /// <exception cref="System.Exception">Inference is already running for this session</exception>
        public async Task<string> InferTextCompleteAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            var inferResult = await InferAsync(sessionId, prompt, inferenceConfig, cancellationToken)
                .Where(x => x.TokenType == TokenType.Content)
                .Select(x => x.Content)
                .ToListAsync(cancellationToken: cancellationToken);

            return string.Concat(inferResult);
        }

        /// <summary>
        /// Cancels the current inference action.
        /// </summary>
        /// <param name="sessionId">The session identifier.</param>
        /// <returns></returns>
        public Task<bool> CancelAsync(string sessionId)
        {
            if (_modelSessions.TryGetValue(sessionId, out var modelSession))
            {
                modelSession.CancelInfer();
                return Task.FromResult(true);
            }
            return Task.FromResult(false);
        }

        /// <summary>
        /// Gets the elapsed time in milliseconds.
        /// </summary>
        /// <param name="timestamp">The timestamp.</param>
        /// <returns></returns>
        private static int GetElapsed(long timestamp)
        {
            return (int)Stopwatch.GetElapsedTime(timestamp).TotalMilliseconds;
        }
    }
}
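
For reference, below is a minimal sketch of how a caller (for example, a WebSocket connection handler) might drive this service: create a session for the connection id, stream the tokens produced for a prompt, then close the session when the connection ends. The consumer class and its member names are illustrative only; it is assumed that the IModelSessionService interface exposes the members implemented above and that the service is resolved through dependency injection.

using LLama.Web.Models;
using LLama.Web.Services;

namespace LLama.Web.Examples // hypothetical namespace for this sketch
{
    /// <summary>
    /// Hypothetical consumer of ModelSessionService; not part of ModelSessionService.cs.
    /// </summary>
    public class ExampleSessionConsumer
    {
        private readonly IModelSessionService _modelSessionService;

        public ExampleSessionConsumer(IModelSessionService modelSessionService)
        {
            // Assumes the service is registered for dependency injection as IModelSessionService.
            _modelSessionService = modelSessionService;
        }

        /// <summary>
        /// Creates a session for the connection, streams a single prompt, then closes the session.
        /// </summary>
        public async Task RunAsync(string connectionId, LLama.Web.Common.SessionOptions sessionConfig, string prompt, CancellationToken cancellationToken)
        {
            // One session per connection, mirroring the per-connection lifetime described above.
            await _modelSessionService.CreateAsync(connectionId, sessionConfig, cancellationToken: cancellationToken);
            try
            {
                // Begin/End/Cancel tokens carry status text; Content tokens carry the generated text.
                await foreach (var token in _modelSessionService.InferAsync(connectionId, prompt, cancellationToken: cancellationToken))
                {
                    if (token.TokenType == TokenType.Content)
                        Console.Write(token.Content);
                }
            }
            finally
            {
                // Remove the session and its context when the connection goes away.
                await _modelSessionService.CloseAsync(connectionId);
            }
        }

        /// <summary>
        /// Non-streaming variant: returns the concatenated Content tokens as a single string.
        /// </summary>
        public Task<string> CompleteAsync(string connectionId, string prompt, CancellationToken cancellationToken)
        {
            return _modelSessionService.InferTextCompleteAsync(connectionId, prompt, cancellationToken: cancellationToken);
        }
    }
}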